# Saliency Maps

A saliency map is a visualization that assigns an importance score to each part of an input. For an image model, the saliency map usually assigns a score to each pixel or image region. For a text model, it may assign a score to each token. The goal is to estimate which parts of the input most influenced the model’s prediction.

Saliency methods are often used for model inspection. They can reveal whether a classifier looks at the object, the background, a watermark, a border artifact, or another unintended feature. They do not prove that a model reasons correctly. They are diagnostic tools, not complete explanations.

### Basic Idea

Let a model $f_\theta$ map an input $x$ to a vector of class scores, and let $s_c$ denote the score for class $c$:

$$
s_c = f_\theta(x)_c.
$$

A saliency method asks how sensitive this score is to each input component. The simplest form uses the gradient of the class score with respect to the input:

$$
S(x) = \left| \nabla_x s_c \right|.
$$

Each entry of $S(x)$ measures, to first order, how strongly a small change in the corresponding input component would change the class score.
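The first-order reading of the gradient can be checked numerically. The sketch below uses a hypothetical toy score (a `tanh` of a weighted sum, not a real model) with an analytic gradient, nudges one input component at a time, and compares the actual change in the score with the change the gradient predicts. The names `score`, `score_gradient`, `w`, and `eps` are illustrative assumptions.

```python
import math

def score(x, w):
    """Toy class score (hypothetical): tanh of a weighted sum of the inputs."""
    return math.tanh(sum(wi * xi for wi, xi in zip(w, x)))

def score_gradient(x, w):
    """Analytic gradient of `score` with respect to each input component."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    return [(1.0 - math.tanh(z) ** 2) * wi for wi in w]

w = [0.5, -2.0, 0.0]   # the third input has no influence on the score
x = [1.0, 0.2, 3.0]

grad = score_gradient(x, w)
saliency = [abs(g) for g in grad]

# Nudge one component at a time and compare the actual change in the score
# with the first-order prediction grad[i] * eps.
eps = 1e-4
for i in range(len(x)):
    x_nudged = list(x)
    x_nudged[i] += eps
    actual = score(x_nudged, w) - score(x, w)
    predicted = grad[i] * eps
    print(f"component {i}: saliency={saliency[i]:.4f} "
          f"actual={actual:.3e} predicted={predicted:.3e}")
```

The component with zero weight receives zero saliency, and the component with the largest absolute weight receives the largest saliency. The agreement between `actual` and `predicted` only holds for small `eps`; for large changes the gradient can be a poor guide.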

For images, $x$ may have shape

```python
[B, C, H, W]
```

and the saliency tensor has the same shape. It is often reduced across color channels to produce a heatmap of shape

```python
[B, H, W]
```
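The channel reduction is a design choice rather than a fixed rule. The sketch below applies three common reductions (maximum, sum, and L2 norm over the channel dimension) to a small hand-written tensor that stands in for an input gradient; the values are made up for illustration.

```python
import torch

# Stand-in for an input gradient of shape [B, C, H, W] = [1, 3, 1, 2].
grad = torch.tensor([[[[1.0, -2.0]],
                      [[-3.0, 0.5]],
                      [[2.0, 0.0]]]])

abs_grad = grad.abs()

heat_max = abs_grad.amax(dim=1)          # largest channel magnitude per pixel
heat_sum = abs_grad.sum(dim=1)           # total magnitude per pixel
heat_l2 = grad.pow(2).sum(dim=1).sqrt()  # Euclidean norm across channels

print(heat_max.shape)  # torch.Size([1, 1, 2])
print(heat_max)        # tensor([[[3., 2.]]])
print(heat_sum)        # tensor([[[6.0000, 2.5000]]])
```

All three reductions give a `[B, H, W]` heatmap, but they can rank pixels differently, so it helps to state which one a figure uses.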

### Vanilla Gradient Saliency

Vanilla gradient saliency is the simplest saliency method.

Given an input image $x$, a target class $c$, and a model $f_\theta$, compute the gradient of the class score with respect to the input:

$$
S = \left|\frac{\partial f_\theta(x)_c}{\partial x}\right|.
$$

In PyTorch:

```python
import torch

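def vanilla_saliency(model, x, target_class=None):
    model.eval()

    # Track gradients on a detached copy so the caller's tensor is untouched.
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)

    # Default to the predicted class for each example.
    if target_class is None:
        target_class = logits.argmax(dim=1)

    # Summing over the batch is fine for typical models in eval mode: each
    # score depends only on its own input, so each example gets its own gradient.
    scores = logits.gather(1, target_class.view(-1, 1)).sum()

    model.zero_grad(set_to_none=True)
    scores.backward()

    saliency = x.grad.abs()

    # Reduce RGB channels to one heatmap per image.
    saliency = saliency.max(dim=1).values

    return saliency.detach()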
```

This function returns a heatmap with shape `[B, H, W]`.

Vanilla saliency is cheap because it requires only one forward pass and one backward pass. Its weakness is noise. Gradients may highlight small high-frequency patterns rather than semantically meaningful regions.

### Class Scores Versus Loss Gradients

A saliency map should be clear about which quantity is being differentiated.

One option is to differentiate the class score:

$$
\nabla_x f_\theta(x)_c.
$$

This asks: which input components increase or decrease the score for class $c$?

Another option is to differentiate the loss:

$$
\nabla_x L(f_\theta(x), y).
$$

This asks: which input components would most change the training objective?

For interpretation, class-score gradients are usually more direct. For adversarial attacks, loss gradients are usually more useful. The same gradient machinery is used in both cases, but the objective is different.
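The difference between the two objectives is easy to see on a linear classifier, where both gradients have closed forms. The sketch below is a minimal illustration under that assumption: the `torch.nn.Linear` model, the input, and the label are all made up. For logits $Wx + b$, the class-score gradient is row $c$ of $W$, while the cross-entropy gradient is $(p - \text{onehot}(y))^\top W$, where $p$ is the softmax output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

model = torch.nn.Linear(4, 3)              # hypothetical toy classifier
x = torch.randn(1, 4, requires_grad=True)  # one input with four features
y = torch.tensor([2])                      # assumed true label
c = 2                                      # class whose score is explained

logits = model(x)

# Gradient of the class score with respect to the input.
score_grad, = torch.autograd.grad(logits[0, c], x, retain_graph=True)

# Gradient of the training loss with respect to the input.
loss_grad, = torch.autograd.grad(F.cross_entropy(logits, y), x)

# Closed forms for a linear model, used here as a check.
probs = logits.detach().softmax(dim=1)
one_hot = F.one_hot(y, num_classes=3).to(probs.dtype)
expected_loss_grad = (probs - one_hot) @ model.weight.detach()

print(torch.allclose(score_grad[0], model.weight[c].detach()))    # True
print(torch.allclose(loss_grad, expected_loss_grad, atol=1e-6))   # True
```

The loss gradient mixes in every class through the softmax, which is why it generally differs from the class-score gradient even when $y = c$.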

### Positive and Negative Evidence

Taking the absolute value of the gradient removes the sign. This shows sensitivity, but it loses direction.

A signed gradient keeps direction:

$$
S = \frac{\partial f_\theta(x)_c}{\partial x}.
$$

Positive values indicate that increasing the input component would increase the class score. Negative values indicate that increasing it would decrease the class score.

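For images, signed saliency can be harder to interpret: each pixel has several channels, and the sign refers to the preprocessed input (for example, after normalization) rather than to raw intensity. For tabular data, where each feature has a clear meaning and direction, signed gradients may be more meaningful.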
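When the sign is kept, positive and negative evidence are often displayed separately. A minimal sketch, assuming a signed gradient tensor is already available; `split_signed_saliency` is a hypothetical helper name, and the example values are made up.

```python
import torch

def split_signed_saliency(grad):
    """Split a signed gradient into non-negative positive and negative parts."""
    positive = grad.clamp(min=0.0)     # increasing these inputs raises the score
    negative = (-grad).clamp(min=0.0)  # increasing these inputs lowers the score
    return positive, negative

grad = torch.tensor([[0.3, -1.2, 0.0, 2.0]])
pos, neg = split_signed_saliency(grad)

print(pos)  # tensor([[0.3000, 0.0000, 0.0000, 2.0000]])
print(neg)  # tensor([[0.0000, 1.2000, 0.0000, 0.0000]])
```

The two parts satisfy `pos - neg == grad` and `pos + neg == grad.abs()`, so the unsigned map can always be recovered from them.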

### SmoothGrad

SmoothGrad reduces noise by averaging saliency maps over noisy copies of the same input.

Let $\eta_i$ be random noise. SmoothGrad computes:

$$
S_{\text{smooth}}(x) =
\frac{1}{n}
\sum_{i=1}^n
\left|
\nabla_x f_\theta(x + \eta_i)_c
\right|.
$$

The intuition is that meaningful regions should remain important under small input noise, while random gradient noise should average out.

In PyTorch:

```python
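def smoothgrad_saliency(model, x, target_class=None, samples=32, noise_std=0.1):
    model.eval()

    # Fix the target class once, from the clean input, so that every noisy
    # sample explains the same class.
    if target_class is None:
        with torch.no_grad():
            target_class = model(x).argmax(dim=1)

    saliency_sum = torch.zeros(
        x.shape[0], x.shape[2], x.shape[3],
        device=x.device, dtype=x.dtype,
    )

    for _ in range(samples):
        noise = torch.randn_like(x) * noise_std
        # The clamp assumes pixel values in [0, 1]; drop it for normalized inputs.
        x_noisy = torch.clamp(x + noise, 0.0, 1.0)

        saliency = vanilla_saliency(model, x_noisy, target_class)
        saliency_sum += saliency

    return saliency_sum / samples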
```

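SmoothGrad is more expensive than vanilla saliency because it requires one forward pass and one backward pass per noisy sample. It often produces cleaner visualizations. The version above averages absolute gradients; averaging the signed gradients and taking the magnitude afterwards is a common variant.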
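A fixed `noise_std` means different things for inputs in `[0, 1]` and for normalized inputs. One common convention is to express the noise scale as a fraction of each input's value range. The helper below is a hypothetical sketch of that idea; the name `noise_std_from_range` and the default `noise_level=0.15` are assumptions for illustration, not recommended settings.

```python
import torch

def noise_std_from_range(x, noise_level=0.15):
    """Per-example noise scale: noise_level * (max - min) of each input.

    Returns a tensor of shape [B, 1, ..., 1] that broadcasts against x.
    """
    flat = x.flatten(start_dim=1)
    value_range = flat.amax(dim=1) - flat.amin(dim=1)
    return (noise_level * value_range).view(-1, *([1] * (x.dim() - 1)))

# Two inputs with different value ranges: [0, 1] and [-2, 2].
x = torch.stack([
    torch.linspace(0.0, 1.0, 12).view(3, 2, 2),
    torch.linspace(-2.0, 2.0, 12).view(3, 2, 2),
])

std = noise_std_from_range(x, noise_level=0.1)
noise = torch.randn_like(x) * std   # broadcasts over C, H, W

print(std.flatten())  # tensor([0.1000, 0.4000])
```

Because the result broadcasts against `x`, it can be used wherever a scalar standard deviation multiplies `torch.randn_like(x)`.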

### Integrated Gradients

Integrated gradients compare the input with a baseline input. The baseline is usually a black image, zero vector, padding token, or another neutral reference.

For input $x$, baseline $x'$, and class score $F(x)$, integrated gradients are defined as:

$$
\operatorname{IG}_i(x) =
(x_i - x'_i)
\int_0^1
\frac{\partial F(x' + \alpha(x - x'))}{\partial x_i}
d\alpha.
$$

The method accumulates gradients along the straight path from the baseline to the real input.

In practice, the integral is approximated with a finite sum:

$$
\operatorname{IG}_i(x)
\approx
(x_i - x'_i)
\frac{1}{m}
\sum_{k=1}^m
\frac{\partial F(x' + \frac{k}{m}(x - x'))}{\partial x_i}.
$$

PyTorch implementation:

```python
def integrated_gradients(model, x, target_class=None, baseline=None, steps=50):
    model.eval()

    if baseline is None:
        baseline = torch.zeros_like(x)

    if target_class is None:
        with torch.no_grad():
            logits = model(x)
            target_class = logits.argmax(dim=1)

    total_grad = torch.zeros_like(x)

    for k in range(1, steps + 1):
        alpha = k / steps
        x_interp = baseline + alpha * (x - baseline)
        x_interp = x_interp.clone().detach().requires_grad_(True)

        logits = model(x_interp)
        scores = logits.gather(1, target_class.view(-1, 1)).sum()

        model.zero_grad(set_to_none=True)
        scores.backward()

        total_grad += x_interp.grad

    avg_grad = total_grad / steps
    attribution = (x - baseline) * avg_grad

    # For images, reduce channels.
    heatmap = attribution.abs().sum(dim=1)

    return heatmap.detach()
```

Integrated gradients are often more stable than vanilla gradients. The choice of baseline matters. A poor baseline can produce misleading attributions.
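One way to check an implementation is the completeness property of integrated gradients: the signed attributions should sum approximately to $F(x) - F(x')$. The sketch below tests this on a hypothetical toy network. It keeps the attributions signed (no absolute value or channel reduction), uses the same right-endpoint sum as above, and evaluates all interpolation steps in one batch. The model, input, and step count are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model: 4 input features, 3 classes.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3)
).eval()

def signed_integrated_gradients(model, x, baseline, target_class, steps=1000):
    """Signed IG for a single input x of shape [D], using a right-endpoint sum."""
    alphas = torch.arange(1, steps + 1, dtype=x.dtype).view(-1, 1) / steps
    # All interpolated points at once: shape [steps, D].
    x_interp = (baseline + alphas * (x - baseline)).requires_grad_(True)
    scores = model(x_interp)[:, target_class].sum()
    grads, = torch.autograd.grad(scores, x_interp)
    return (x - baseline) * grads.mean(dim=0)

x = torch.randn(4)
baseline = torch.zeros(4)
c = 1

attributions = signed_integrated_gradients(model, x, baseline, c)

with torch.no_grad():
    score_difference = model(x.unsqueeze(0))[0, c] - model(baseline.unsqueeze(0))[0, c]

gap = (attributions.sum() - score_difference).abs().item()
print(f"sum of attributions: {attributions.sum().item():.4f}")
print(f"F(x) - F(x'):        {score_difference.item():.4f}")
print(f"gap:                 {gap:.2e}")
```

A large gap usually means too few steps or a bug in the interpolation. A small gap does not by itself show that the baseline was a good choice.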

### Grad-CAM

Gradient-weighted Class Activation Mapping, or Grad-CAM, is commonly used for convolutional networks. It produces a coarse heatmap over spatial regions of a feature map.

Let $A^k$ be feature map $k$ in a convolutional layer, and let $s_c$ be the score for class $c$. Grad-CAM computes one weight per channel by averaging the gradient over the $Z$ spatial positions $(i, j)$ of the feature map:

$$
\alpha_k^c =
\frac{1}{Z}
\sum_i
\sum_j
\frac{\partial s_c}{\partial A_{ij}^k}.
$$

Then it forms a weighted sum of feature maps:

$$
L_{\text{Grad-CAM}}^c =
\operatorname{ReLU}
\left(
\sum_k
\alpha_k^c A^k
\right).
$$

The result is resized to the input image size.

Grad-CAM is useful because it highlights larger semantic regions rather than individual pixels. It is less precise spatially, but often easier to interpret.

A minimal PyTorch implementation can use hooks to capture the activations and gradients of the target layer:

```python
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

        self.forward_handle = target_layer.register_forward_hook(self._save_activation)
        self.backward_handle = target_layer.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, inputs, output):
        self.activations = output

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove_hooks(self):
        self.forward_handle.remove()
        self.backward_handle.remove()

    def __call__(self, x, target_class=None):
        logits = self.model(x)

        if target_class is None:
            target_class = logits.argmax(dim=1)

        score = logits.gather(1, target_class.view(-1, 1)).sum()

        self.model.zero_grad(set_to_none=True)
        score.backward()

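        # Channel weights: gradients averaged over the spatial dimensions.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)

        # Resize the coarse map to the input resolution: [B, 1, h, w] -> [B, H, W].
        cam = torch.nn.functional.interpolate(
            cam, size=x.shape[-2:], mode="bilinear", align_corners=False
        ).squeeze(1)

        # Rescale each heatmap to [0, 1] for display.
        cam_min = cam.amin(dim=(1, 2), keepdim=True)
        cam_max = cam.amax(dim=(1, 2), keepdim=True)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        return cam.detach()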
```

For a ResNet-like model, the target layer is often the last convolutional block.
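When the model can be split into a convolutional feature extractor and a classifier head, the same computation can be sketched without hooks by calling `torch.autograd.grad` on the intermediate activations. The example below is a minimal illustration on a hypothetical toy network; `features`, `head`, and `grad_cam_no_hooks` are made-up names, and real architectures may not split this cleanly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model, split into a convolutional trunk and a classifier head.
features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))

def grad_cam_no_hooks(features, head, x, target_class):
    activations = features(x)                     # A^k, shape [B, K, h, w]
    logits = head(activations)
    score = logits.gather(1, target_class.view(-1, 1)).sum()

    # d s_c / d A^k, same shape as the activations.
    grads, = torch.autograd.grad(score, activations)

    weights = grads.mean(dim=(2, 3), keepdim=True)                      # alpha_k^c
    cam = torch.relu((weights * activations).sum(dim=1, keepdim=True))  # [B, 1, h, w]

    # Upsample the coarse map to the input resolution.
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    return cam.squeeze(1).detach()                                      # [B, H, W]

x = torch.rand(2, 3, 8, 8)
target = torch.tensor([0, 3])
cam = grad_cam_no_hooks(features, head, x, target)
print(cam.shape)  # torch.Size([2, 8, 8])
```

The feature map here is `4 x 4`, so each heatmap cell covers a `2 x 2` block of input pixels before interpolation, which illustrates why Grad-CAM maps are coarse.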

### Occlusion Sensitivity

Occlusion sensitivity is a perturbation-based saliency method. Instead of using gradients, it masks parts of the input and observes how the model output changes.

For an image, we slide a patch over the image and replace that patch with a baseline value. If masking a region greatly reduces the score for class $c$, that region is considered important.

The method is easy to understand but computationally expensive. It requires many forward passes.

```python
def occlusion_saliency(model, x, target_class=None, patch_size=16, stride=8):
    model.eval()

    with torch.no_grad():
        logits = model(x)
        if target_class is None:
            target_class = logits.argmax(dim=1)
        base_scores = logits.gather(1, target_class.view(-1, 1)).squeeze(1)

    B, C, H, W = x.shape
    heatmap = torch.zeros(B, H, W, device=x.device)
    counts = torch.zeros(B, H, W, device=x.device)

    for top in range(0, H, stride):
        for left in range(0, W, stride):
            bottom = min(top + patch_size, H)
            right = min(left + patch_size, W)

            x_occ = x.clone()
            x_occ[:, :, top:bottom, left:right] = 0.0

            with torch.no_grad():
                logits_occ = model(x_occ)
                scores_occ = logits_occ.gather(1, target_class.view(-1, 1)).squeeze(1)

            drop = base_scores - scores_occ
            heatmap[:, top:bottom, left:right] += drop.view(B, 1, 1)
            counts[:, top:bottom, left:right] += 1

    return heatmap / counts.clamp_min(1)
```

Occlusion methods can capture nonlinear effects better than local gradients. Their cost grows quickly with image size and patch resolution.
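The cost can be estimated before running anything. With the loop structure used above (patch positions at `range(0, H, stride)` and `range(0, W, stride)`), the number of model calls depends on the image size and the stride, plus one unoccluded call for the base scores. The helper name below is hypothetical.

```python
import math

def occlusion_forward_passes(height, width, stride):
    """Model calls for a sliding-patch scan, plus one unoccluded baseline call."""
    positions = math.ceil(height / stride) * math.ceil(width / stride)
    return positions + 1

for stride in (32, 16, 8, 4):
    print(stride, occlusion_forward_passes(224, 224, stride))
# 32 50
# 16 197
# 8 785
# 4 3137
```

Halving the stride roughly quadruples the number of calls. Each call processes the whole batch, so the per-image cost is the same count.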

### Saliency for Text Models

For text, the input tokens are discrete. Token IDs are integer indices, so the class score has no meaningful gradient with respect to them. Instead, saliency is computed with respect to the token embeddings, which are continuous.

Suppose token IDs have shape

```python
[B, T]
```

and embeddings have shape

```python
[B, T, D]
```

A saliency score per token can be computed by differentiating the class score with respect to the embedding vectors, then reducing over the embedding dimension:

$$
S_{bt} =
\left\|
\frac{\partial s_c}{\partial e_{bt}}
\right\|_2.
$$

This gives one score per token.

The exact implementation depends on the model architecture. In transformer libraries, it is often easier to use embedding hooks or pass `inputs_embeds` directly if the model supports it.
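The idea can be sketched on a toy model that exposes its embedding layer. Everything in the example is hypothetical: `TinyTextClassifier`, its `forward_from_embeddings` method, and the token IDs are made up, and the sketch ignores padding and attention masks, which a real model would need to handle.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyTextClassifier(nn.Module):
    """Hypothetical toy model: embed tokens, transform, mean-pool, classify."""

    def __init__(self, vocab_size=20, dim=8, num_classes=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.hidden = nn.Linear(dim, dim)
        self.classifier = nn.Linear(dim, num_classes)

    def forward_from_embeddings(self, embeds):
        pooled = torch.tanh(self.hidden(embeds)).mean(dim=1)  # [B, D]
        return self.classifier(pooled)                        # [B, num_classes]

def token_saliency(model, token_ids, target_class=None):
    model.eval()
    # Look up embeddings, then treat them as the differentiable input.
    embeds = model.embedding(token_ids).detach().requires_grad_(True)  # [B, T, D]
    logits = model.forward_from_embeddings(embeds)

    if target_class is None:
        target_class = logits.argmax(dim=1)

    score = logits.gather(1, target_class.view(-1, 1)).sum()
    grads, = torch.autograd.grad(score, embeds)

    # L2 norm over the embedding dimension: one score per token.
    return grads.norm(dim=-1)                                          # [B, T]

model = TinyTextClassifier()
token_ids = torch.tensor([[1, 5, 7, 2], [3, 3, 9, 0]])
scores = token_saliency(model, token_ids)
print(scores.shape)  # torch.Size([2, 4])
```

Detaching the embeddings makes them a leaf tensor, so the gradient is taken with respect to the embedding vectors seen by the model rather than the embedding table.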

### Failure Modes of Saliency Maps

Saliency maps are easy to misuse. They are visual and intuitive, but their meaning is limited.

Common failure modes include:

| Failure mode | Description |
|---|---|
| Gradient noise | Heatmaps may emphasize high-frequency artifacts |
| Saturation | Gradients may be small even for important features |
| Baseline dependence | Integrated gradients depend on reference input |
| Layer dependence | Grad-CAM depends on the chosen layer |
| Visual overinterpretation | A plausible heatmap may not imply correct reasoning |
| Model insensitivity | Some saliency methods produce similar heatmaps even when the model weights change |

A saliency map answers a narrow question about sensitivity or attribution under a chosen method. It does not reveal the full causal mechanism of the model.
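Saturation is the easiest of these failure modes to reproduce by hand. In the sketch below, a hypothetical one-feature score `sigmoid(10 * x)` is evaluated at `x = 2`. The gradient is almost zero because the sigmoid is flat there, yet setting the feature to zero cuts the score roughly in half.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def saturating_score(x, w=10.0):
    """Toy score (hypothetical) that saturates for large positive inputs."""
    return sigmoid(w * x)

def saturating_score_gradient(x, w=10.0):
    """Analytic derivative of `saturating_score` with respect to x."""
    s = sigmoid(w * x)
    return w * s * (1.0 - s)

x = 2.0
gradient = saturating_score_gradient(x)
drop_if_removed = saturating_score(x) - saturating_score(0.0)

print(f"gradient at x=2:       {gradient:.2e}")         # about 2e-08
print(f"score drop when x -> 0: {drop_if_removed:.3f}")  # about 0.500
```

A gradient-based map would rank this feature as unimportant, while a perturbation-based or baseline-based method would not.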

### Sanity Checks

A saliency method should pass basic sanity checks.

One test is parameter randomization. If the model weights are randomized, the saliency map should change substantially. If the heatmap remains similar, the method may mostly reflect the input image rather than the learned model.
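The mechanics of a parameter randomization check can be sketched as follows. This is a simplified illustration, not a full protocol: the toy network is itself untrained, `gradient_heatmap`, `randomized_copy`, and `heatmap_similarity` are hypothetical helper names, and cosine similarity is only one possible comparison (rank correlation or structural similarity are alternatives).

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

def gradient_heatmap(model, x, target_class):
    """Absolute input gradient reduced over channels: [B, C, H, W] -> [B, H, W]."""
    x = x.clone().detach().requires_grad_(True)
    score = model(x).gather(1, target_class.view(-1, 1)).sum()
    grad, = torch.autograd.grad(score, x)
    return grad.abs().amax(dim=1)

def randomized_copy(model):
    """Copy the model and re-initialize every layer that defines reset_parameters()."""
    clone = copy.deepcopy(model)
    for module in clone.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    return clone

def heatmap_similarity(a, b):
    """Per-example cosine similarity between two batches of heatmaps."""
    return F.cosine_similarity(a.flatten(1), b.flatten(1), dim=1)

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5),
).eval()
x = torch.rand(2, 3, 8, 8)
target = torch.tensor([1, 4])

original_map = gradient_heatmap(model, x, target)
random_map = gradient_heatmap(randomized_copy(model), x, target)
similarity = heatmap_similarity(original_map, random_map)
print(similarity)  # one value per example, between 0 and 1
```

Because heatmaps are non-negative, cosine similarity tends to be high even for unrelated maps, so the number is best compared against a reference such as the similarity between two independently randomized models.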

Another test is label randomization. A model trained on random labels should produce different explanations from a model trained on real labels.

A third test is input perturbation. Removing highly salient regions should reduce the target score more than removing low-saliency regions. If it does not, the saliency map may not be faithful.
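The input perturbation test can be sketched as a deletion comparison: remove the most salient pixels, remove the same number of least salient pixels, and compare the score drops. The helper `deletion_score_drop`, the fill value of zero, and the toy linear model with positive weights are all assumptions chosen so the expected outcome is easy to verify by hand.

```python
import torch

def deletion_score_drop(model, x, saliency, target_class,
                        fraction=0.1, most_salient=True, fill_value=0.0):
    """Score drop after replacing a fraction of pixels, chosen by saliency rank."""
    B, C, H, W = x.shape
    k = max(1, int(fraction * H * W))

    flat = saliency.flatten(1)                                 # [B, H*W]
    idx = flat.topk(k, dim=1, largest=most_salient).indices
    mask = torch.zeros_like(flat).scatter_(1, idx, 1.0).bool() # True = delete
    mask = mask.view(B, 1, H, W)                               # shared across channels

    x_deleted = torch.where(mask, torch.full_like(x, fill_value), x)

    with torch.no_grad():
        before = model(x).gather(1, target_class.view(-1, 1)).squeeze(1)
        after = model(x_deleted).gather(1, target_class.view(-1, 1)).squeeze(1)
    return before - after

torch.manual_seed(0)
# Toy linear model with positive weights, so deleting a pixel lowers the score.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 2, bias=False)).eval()
with torch.no_grad():
    model[1].weight.copy_(torch.rand(2, 16))

x = torch.ones(1, 1, 4, 4)
target = torch.tensor([0])
# For a linear model the input gradient is the weight row itself.
saliency = model[1].weight[0].detach().abs().view(1, 4, 4)

drop_top = deletion_score_drop(model, x, saliency, target, fraction=0.25, most_salient=True)
drop_bottom = deletion_score_drop(model, x, saliency, target, fraction=0.25, most_salient=False)
print(drop_top.item() >= drop_bottom.item())  # True
```

On real models the replaced pixels create out-of-distribution inputs, so the result also depends on the fill value; comparing against random deletion is a useful additional reference.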

### Practical Use

Saliency maps are useful when treated as one signal among several. They can help find dataset artifacts, preprocessing bugs, shortcut features, and suspicious model behavior.

For image classifiers, inspect saliency maps across many examples, not only cherry-picked cases. Compare correct and incorrect predictions. Compare clean and shifted data. Compare different model architectures.

For text models, token saliency can show whether a classifier relies on names, punctuation, demographic tokens, prompt artifacts, or label words. This is especially useful for auditing spurious correlations.

For safety-critical systems, saliency should support investigation, not certification. A model with plausible saliency can still fail under distribution shift, adversarial attack, or causal intervention.

### Summary

Saliency maps assign importance scores to input components. Vanilla gradients measure local sensitivity. SmoothGrad averages noisy gradients. Integrated gradients attribute predictions relative to a baseline. Grad-CAM highlights spatial regions using convolutional feature maps. Occlusion sensitivity measures the effect of masking input regions.

These methods are practical and widely used, but they are approximate. Their results depend on the chosen objective, layer, baseline, and visualization method. A saliency map should be read as a diagnostic artifact, not as proof that the model understands the task.

