# Data Augmentation Strategies

Data augmentation creates modified versions of training examples without changing their labels. In image classification, common augmentations include random crops, flips, rotations, color changes, blur, noise, erasing, Mixup, and CutMix.

The goal is to make the model less sensitive to irrelevant variation. A classifier should recognize a cat whether the image is slightly shifted, brighter, darker, cropped, or photographed from a different angle. Augmentation encodes these assumptions into the training process.

### Why Augmentation Helps

A neural network can memorize a small training set. It may learn exact backgrounds, lighting conditions, camera artifacts, or image layouts rather than the class concept. Augmentation reduces this risk by showing the model many plausible variants of the same example.

Let $x$ be an image and $y$ its label. An augmentation is a transformation

$$
\tilde{x} = T(x),
$$

where $T$ is sampled from a family of label-preserving transformations.

The model is then trained on

$$
(\tilde{x}, y)
$$

instead of only

$$
(x, y).
$$

The training objective becomes an expectation over both data and transformations:

$$
\mathbb{E}_{(x,y)}\mathbb{E}_{T}\left[L(f_\theta(T(x)), y)\right].
$$

This encourages the model to give stable predictions under transformations that should not change the label.
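In practice, the inner expectation over $T$ is estimated by sampling. The sketch below assumes `sample_transform` is a hypothetical callable that draws a fresh random transformation and applies it to a batch. Averaging the loss over a few independent draws gives a Monte Carlo estimate of the inner expectation. Ordinary training uses a single draw per example per step.

```python
import torch
import torch.nn.functional as F


def augmented_loss(model, x, y, sample_transform, num_draws=1):
    """Monte Carlo estimate of E_T[L(f(T(x)), y)] for one batch.

    `sample_transform` is a hypothetical callable that applies a freshly
    sampled label-preserving transformation to the batch `x`.
    """
    losses = []
    for _ in range(num_draws):
        x_aug = sample_transform(x)  # draw T and compute T(x)
        logits = model(x_aug)
        losses.append(F.cross_entropy(logits, y))
    # Averaging over draws approximates the expectation over T.
    return torch.stack(losses).mean()
```

With the identity as `sample_transform`, this reduces to the ordinary loss on $(x, y)$.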

### Label-Preserving Transformations

An augmentation is valid only when it preserves the label. This depends on the task.

For ordinary object classification, horizontal flipping is often safe. For digit recognition, horizontal flipping may turn a meaningful digit into an invalid or different symbol. For medical images, flipping may change anatomical laterality. For satellite images, rotation may be acceptable. For traffic signs, strong rotation may create unrealistic examples.

| Augmentation | Usually safe for | Risky for |
|---|---|---|
| Horizontal flip | Natural images | Text, digits, medical laterality |
| Vertical flip | Satellite, microscopy | Natural object photos |
| Random crop | Object-centered images | Small objects near image edge |
| Rotation | Satellite, microscopy | Upright objects, text |
| Color jitter | Natural images | Color-coded classes |
| Blur | Robustness to focus variation | Fine-grained texture tasks |
| Erasing | Occlusion robustness | Tiny objects |
| Mixup | General classification | Localization-sensitive labels |
| CutMix | Object classification | Images with multiple small objects |

The correct augmentation policy follows the semantics of the data, not a fixed recipe.

### Basic PyTorch Transforms

Torchvision provides common image transforms.

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
```

The validation transform should usually be deterministic:

```python
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
```

Training uses randomness. Validation should measure model performance under a fixed preprocessing protocol.
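When the training and validation splits come from the same underlying dataset (for example via `torch.utils.data.random_split`), both subsets share that dataset's `transform`. Validation then silently receives random augmentation. One workaround is a small wrapper that applies a split-specific transform. `TransformedSubset` is a hypothetical helper, not a torch or torchvision class, and it assumes the wrapped subset returns untransformed `(image, label)` pairs.

```python
from torch.utils.data import Dataset


class TransformedSubset(Dataset):
    """Hypothetical wrapper: applies its own transform to a subset's samples.

    Assumes `subset[i]` returns an untransformed (image, label) pair.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        image, label = self.subset[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

Each split then gets its own policy, e.g. `TransformedSubset(train_part, train_transform)` and `TransformedSubset(val_part, val_transform)`, with no transform set on the base dataset.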

### Random Resized Crop

Random resized crop is one of the most important augmentations for image classification. It selects a random region of the image, resizes it to the model input size, and uses it as the training example.

```python
transforms.RandomResizedCrop(
    size=224,
    scale=(0.08, 1.0),
    ratio=(3 / 4, 4 / 3),
)
```

The `scale` controls the area of the crop relative to the original image. The `ratio` controls the aspect ratio.

A wide scale range gives strong augmentation. A narrower range gives more conservative augmentation.

```python
# Conservative
transforms.RandomResizedCrop(224, scale=(0.7, 1.0))

# Strong
transforms.RandomResizedCrop(224, scale=(0.08, 1.0))
```

Strong crops are useful for large datasets and robust models. For small objects, aggressive cropping can remove the object entirely and create mislabeled examples.
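The sketch below makes `scale` and `ratio` concrete by sampling a crop box as the torchvision documentation describes. The target area is drawn uniformly from `scale` times the image area, and the aspect ratio is drawn log-uniformly from `ratio`. It is a simplified, hypothetical re-implementation, and `sample_crop_box` is not a torchvision function. The real transform's fallback after failed attempts is only approximated here.

```python
import math
import random


def sample_crop_box(width, height, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                    attempts=10, rng=random):
    """Simplified sketch of RandomResizedCrop's box sampling.

    Returns (left, top, crop_width, crop_height). Hypothetical helper.
    """
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # log-uniform ratio
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            left = rng.randint(0, width - w)  # inclusive on both ends
            top = rng.randint(0, height - h)
            return left, top, w, h
    # Simplified fallback: use the whole image (torchvision instead takes
    # a central crop with the aspect ratio clamped to `ratio`).
    return 0, 0, width, height
```

With `scale=(0.08, 1.0)`, a crop may keep as little as 8% of the image area. The selected box is then resized to the output size.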

### Flips and Rotations

Horizontal flipping is common for natural images:

```python
transforms.RandomHorizontalFlip(p=0.5)
```

Vertical flipping is more domain-specific:

```python
transforms.RandomVerticalFlip(p=0.5)
```

Small rotations can improve robustness to camera angle:

```python
transforms.RandomRotation(degrees=10)
```

For domains where orientation has no semantic meaning, larger rotations may be acceptable:

```python
transforms.RandomRotation(degrees=180)
```

The same operation may be correct in one domain and wrong in another. Augmentation policy should be treated as part of the dataset definition.
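For orientation-free domains such as microscopy or overhead imagery, one option is to rotate only by multiples of 90 degrees. These rotations are exact, so they avoid interpolation and empty corners. The sketch below is a hypothetical tensor transform, `RandomRot90`, which assumes a `(C, H, W)` tensor and therefore belongs after `ToTensor()`.

```python
import torch


class RandomRot90:
    """Hypothetical transform: rotate a (C, H, W) tensor by k * 90 degrees."""

    def __call__(self, x):
        k = int(torch.randint(0, 4, (1,)).item())  # k in {0, 1, 2, 3}
        # Rotate in the spatial plane (H, W); exact, no interpolation.
        return torch.rot90(x, k, dims=(1, 2))
```

For non-square inputs, odd values of `k` swap height and width, so the transform is best suited to square images or crops.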

### Color and Lighting Augmentation

Color jitter changes brightness, contrast, saturation, and hue.

```python
transforms.ColorJitter(
    brightness=0.2,
    contrast=0.2,
    saturation=0.2,
    hue=0.05,
)
```

This helps when lighting and camera conditions vary. It may hurt when color is part of the label. For example, color changes may be unsafe for classifying plant disease severity, traffic lights, or chemical assay images.
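Each scalar argument is shorthand for a sampling range. According to the torchvision documentation, a value `b` for brightness, contrast, or saturation means a multiplicative factor drawn uniformly from `[max(0, 1 - b), 1 + b]`. A value `h` for hue means a shift drawn from `[-h, h]`, with `h` at most 0.5. The hypothetical helper below only spells out that rule.

```python
def jitter_ranges(brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0):
    """Return the sampling interval implied by each ColorJitter argument."""
    def factor_range(v):
        # Multiplicative factor; 1.0 means "unchanged".
        return (max(0.0, 1.0 - v), 1.0 + v)

    if not 0.0 <= hue <= 0.5:
        raise ValueError("hue must be in [0, 0.5]")
    return {
        "brightness": factor_range(brightness),
        "contrast": factor_range(contrast),
        "saturation": factor_range(saturation),
        "hue": (-hue, hue),  # additive shift of the hue channel
    }
```

For the settings above, `jitter_ranges(0.2, 0.2, 0.2, 0.05)` gives brightness, contrast, and saturation factors in `(0.8, 1.2)` and a hue shift in `(-0.05, 0.05)`.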

A moderate image classification transform may look like this:

```python
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
```

### Blur, Noise, and Sharpness

Gaussian blur simulates out-of-focus images. It is only a rough proxy for motion blur, which smears the image along a single direction.

```python
transforms.GaussianBlur(
    kernel_size=3,
    sigma=(0.1, 2.0),
)
```

Random sharpness adjustment changes edge contrast:

```python
transforms.RandomAdjustSharpness(
    sharpness_factor=2,
    p=0.3,
)
```

Noise can be added with a custom transform:

```python
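import torch

class AddGaussianNoise:
    """Add zero-mean Gaussian noise to a float image tensor in [0, 1]."""

    def __init__(self, std=0.05):
        self.std = std

    def __call__(self, x):
        noise = torch.randn_like(x) * self.std
        # Clamp back to the [0, 1] pixel range produced by ToTensor().
        return torch.clamp(x + noise, 0.0, 1.0)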
```

The clamp in `AddGaussianNoise` assumes pixel values in [0, 1], so noise should be applied after `ToTensor()` and before normalization:

```python
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    AddGaussianNoise(std=0.03),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
```

### Random Erasing

Random erasing removes a random rectangle from the image tensor. It encourages robustness to occlusion.

```python
transforms.RandomErasing(
    p=0.25,
    scale=(0.02, 0.2),
    ratio=(0.3, 3.3),
)
```

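`RandomErasing` operates on tensors, so it must appear after `ToTensor()`. It usually goes after normalization as well. With the default `value=0`, the erased rectangle then takes the per-channel mean color instead of black.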

```python
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.RandomErasing(p=0.25),
])
```

Random erasing can hurt when the object is small. If the erased region covers the class evidence, the label may become unreliable.

### Mixup

Mixup creates a convex combination of two images and their labels.

Given two examples $(x_i, y_i)$ and $(x_j, y_j)$, Mixup constructs

$$
\tilde{x} = \lambda x_i + (1-\lambda)x_j,
$$

$$
\tilde{y} = \lambda y_i + (1-\lambda)y_j.
$$

The value $\lambda$ is sampled from a beta distribution:

$$
\lambda \sim \operatorname{Beta}(\alpha,\alpha).
$$
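The concentration $\alpha$ controls how strongly images are mixed. For small values such as $\alpha = 0.2$, the Beta distribution puts most of its mass near 0 and 1, so most mixed images stay close to one of the two originals. With $\alpha = 1$, $\lambda$ is uniform on $[0, 1]$. A quick empirical check follows. The thresholds 0.1 and 0.9 are arbitrary choices for illustration.

```python
import torch


def fraction_near_endpoints(alpha, num_samples=10_000, seed=0):
    """Fraction of lambda ~ Beta(alpha, alpha) samples below 0.1 or above 0.9."""
    torch.manual_seed(seed)
    lam = torch.distributions.Beta(alpha, alpha).sample((num_samples,))
    return ((lam < 0.1) | (lam > 0.9)).float().mean().item()


for alpha in (0.2, 1.0, 4.0):
    print(alpha, round(fraction_near_endpoints(alpha), 2))
```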

Because the label becomes a soft target, the loss must support soft labels.

```python
import torch
import torch.nn.functional as F

def mixup_batch(images, labels, num_classes, alpha=0.2):
    batch_size = images.size(0)

    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(batch_size, device=images.device)

    mixed_images = lam * images + (1.0 - lam) * images[perm]

    y1 = F.one_hot(labels, num_classes=num_classes).float()
    y2 = F.one_hot(labels[perm], num_classes=num_classes).float()
    mixed_labels = lam * y1 + (1.0 - lam) * y2

    return mixed_images, mixed_labels
```

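For soft labels, cross-entropy can be written directly. Since PyTorch 1.10, `F.cross_entropy` also accepts class-probability targets and gives the same result: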

```python
def soft_cross_entropy(logits, soft_targets):
    log_probs = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_probs).sum(dim=1).mean()
```

Training step:

```python
images, labels = images.to(device), labels.to(device)

images, soft_labels = mixup_batch(
    images,
    labels,
    num_classes=num_classes,
    alpha=0.2,
)

logits = model(images)
loss = soft_cross_entropy(logits, soft_labels)
```
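Because cross-entropy is linear in the target distribution, the soft-label loss equals a $\lambda$-weighted sum of two ordinary hard-label losses. Some codebases use this form to keep integer labels. The self-contained check below uses random logits and a fixed $\lambda = 0.7$ chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch_size = 5, 8
logits = torch.randn(batch_size, num_classes)
labels = torch.randint(0, num_classes, (batch_size,))
perm = torch.randperm(batch_size)
lam = 0.7

# Soft-target form: cross-entropy against the mixed one-hot labels.
y1 = F.one_hot(labels, num_classes).float()
y2 = F.one_hot(labels[perm], num_classes).float()
soft_targets = lam * y1 + (1.0 - lam) * y2
soft_loss = -(soft_targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Equivalent hard-label form: weight the two standard losses by lam.
hard_loss = (
    lam * F.cross_entropy(logits, labels)
    + (1.0 - lam) * F.cross_entropy(logits, labels[perm])
)

print(torch.allclose(soft_loss, hard_loss))
```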

Mixup usually improves calibration and reduces overconfidence. It may reduce peak training accuracy because the labels are deliberately softened.

### CutMix

CutMix replaces a rectangular region of one image with a region from another image. The target label is mixed according to the area of the patch.

If the pasted patch covers fraction $r$ of the image, the mixed label weights each source image by its visible area:

$$
\tilde{y} = (1-r)y_i + r y_j.
$$
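The patch size is typically derived from a sampled $\lambda$. Scaling both sides of the box by $\sqrt{1-\lambda}$ gives an area fraction of $r = 1-\lambda$ before clipping. When the box is clipped at the image border, the visible fraction shrinks, which is why the implementation below recomputes the label weight from the clipped box. `patch_fraction` is a hypothetical helper that repeats that box arithmetic with an explicit center, shown here for a 224x224 image.

```python
import math


def patch_fraction(width, height, lam, cx, cy):
    """Area fraction r covered by a CutMix box centered at (cx, cy)."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w, cut_h = int(width * cut_ratio), int(height * cut_ratio)
    # Clip the box to the image, as rand_bbox does.
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    return (x2 - x1) * (y2 - y1) / (width * height)


# lam = 0.75: each box side is half the image side, so a centered box covers a quarter.
print(patch_fraction(224, 224, 0.75, cx=112, cy=112))  # 0.25
# The same box centered on a corner is clipped to a quarter of its area.
print(patch_fraction(224, 224, 0.75, cx=0, cy=0))  # 0.0625
```

For the centered box, the label is $0.75\,y_i + 0.25\,y_j$.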

A minimal implementation:

```python
import torch
import torch.nn.functional as F

def rand_bbox(width, height, lam):
    cut_ratio = torch.sqrt(torch.tensor(1.0 - lam)).item()
    cut_w = int(width * cut_ratio)
    cut_h = int(height * cut_ratio)

    cx = torch.randint(width, size=(1,)).item()
    cy = torch.randint(height, size=(1,)).item()

    x1 = max(cx - cut_w // 2, 0)
    y1 = max(cy - cut_h // 2, 0)
    x2 = min(cx + cut_w // 2, width)
    y2 = min(cy + cut_h // 2, height)

    return x1, y1, x2, y2

def cutmix_batch(images, labels, num_classes, alpha=1.0):
    batch_size, _, height, width = images.shape

    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(batch_size, device=images.device)

    x1, y1, x2, y2 = rand_bbox(width, height, lam)

    mixed_images = images.clone()
    mixed_images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]

    patch_area = (x2 - x1) * (y2 - y1)
    image_area = width * height
    lam_adjusted = 1.0 - patch_area / image_area

    y1_onehot = F.one_hot(labels, num_classes=num_classes).float()
    y2_onehot = F.one_hot(labels[perm], num_classes=num_classes).float()

    mixed_labels = (
        lam_adjusted * y1_onehot
        + (1.0 - lam_adjusted) * y2_onehot
    )

    return mixed_images, mixed_labels
```

CutMix is often strong for object classification because it preserves local image structure better than Mixup. It can be less suitable when labels depend on global structure.

### RandAugment and AutoAugment

Manual augmentation policies require choosing operations and strengths by hand. Automated approaches either search for a policy or reduce the choice to a small, easily tuned policy space.

AutoAugment searches for augmentation policies on a dataset. RandAugment drops the expensive search and instead uses two main parameters: the number of operations applied to each image and a single magnitude shared by all of them.

Torchvision includes RandAugment:

```python
from torchvision.transforms import RandAugment

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
```

RandAugment is useful as a strong default for larger datasets. For small or domain-specific datasets, it should be validated carefully. Some operations may violate label semantics.

### Augmentation Strength

Augmentation has a strength parameter, even when it is implicit. Weak augmentation may fail to regularize the model. Strong augmentation may create unrealistic or mislabeled data.

A useful progression is:

| Setting | Example policy |
|---|---|
| None | Resize, center crop, normalize |
| Weak | Random crop, horizontal flip |
| Moderate | Crop, flip, color jitter, random erasing |
| Strong | RandAugment, Mixup, CutMix |
| Very strong | RandAugment plus Mixup/CutMix plus heavy erasing |

Validation performance determines whether an augmentation helps. Training loss will often increase under stronger augmentation. That is not necessarily bad. The relevant question is whether validation metrics improve.

### Augmentation and Class Imbalance

Augmentation can help minority classes by increasing variation. However, ordinary random augmentation does not change the class distribution unless sampling is adjusted.

For imbalanced datasets, use augmentation together with one of the following:

| Method | Effect |
|---|---|
| Weighted sampler | Samples minority classes more often |
| Class-weighted loss | Penalizes minority-class errors more |
| Class-specific augmentation | Applies stronger augmentation to minority classes |
| Balanced batches | Controls class counts per batch |

A weighted sampler can be used in PyTorch:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Assumes train_set.targets is a sequence of integer class labels
# (as in torchvision's ImageFolder and CIFAR10 datasets).
targets = torch.as_tensor(train_set.targets)
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.float()

# Each sample is weighted by the inverse frequency of its class.
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)
```
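The toy check below shows the effect. It builds a synthetic 90/10 imbalanced label list, for illustration only, and counts how often each class is drawn. With inverse-frequency weights, each class should make up roughly half of the samples.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Synthetic, imbalanced labels: 900 of class 0, 100 of class 1.
targets = torch.tensor([0] * 900 + [1] * 100)
class_weights = 1.0 / torch.bincount(targets).float()
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=10_000,
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)

drawn = torch.tensor(list(sampler))  # sampled dataset indices
minority_share = (targets[drawn] == 1).float().mean().item()
print(f"minority share: {minority_share:.2f}")
```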

Then use:

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_set,
    batch_size=64,
    sampler=sampler,
    num_workers=4,
)
```

Do not also pass `shuffle=True` when a `sampler` is given. The two options are mutually exclusive, and `DataLoader` raises an error if both are set.
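Class-specific augmentation, listed in the table above, can be implemented as a dataset wrapper that picks a transform based on the label. `ClassAwareTransform` is a hypothetical helper. It assumes the base dataset returns untransformed `(image, label)` pairs and that the user chooses `minority_classes`, for example from the class counts.

```python
from torch.utils.data import Dataset


class ClassAwareTransform(Dataset):
    """Hypothetical wrapper: stronger augmentation for selected classes."""

    def __init__(self, base, default_transform, strong_transform,
                 minority_classes):
        self.base = base
        self.default_transform = default_transform
        self.strong_transform = strong_transform
        self.minority_classes = set(minority_classes)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, label = self.base[index]
        # Choose the policy from the label, then apply it.
        if int(label) in self.minority_classes:
            image = self.strong_transform(image)
        else:
            image = self.default_transform(image)
        return image, label
```

The stronger policy must still preserve labels. It adds variation to minority examples but does not change how often they are sampled, so it is usually combined with a sampler or a class-weighted loss.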

### Test-Time Augmentation

Test-time augmentation applies several deterministic or random transforms at inference time and averages predictions.

```python
import torch

@torch.no_grad()
def predict_tta(model, image, transforms_list, device):
    model.eval()

    probs_sum = None

    for transform in transforms_list:
        x = transform(image).unsqueeze(0).to(device)
        logits = model(x)
        probs = torch.softmax(logits, dim=1)

        if probs_sum is None:
            probs_sum = probs
        else:
            probs_sum += probs

    probs_mean = probs_sum / len(transforms_list)
    confidence, pred = probs_mean.max(dim=1)

    return pred.item(), confidence.item()
```
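For a whole batch, a common lightweight variant averages predictions on the original and horizontally flipped images. The sketch below assumes the batch is already preprocessed as an `(N, C, H, W)` tensor and that horizontal flips preserve labels for the task. `predict_flip_tta` is a hypothetical helper name.

```python
import torch


@torch.no_grad()
def predict_flip_tta(model, images):
    """Average softmax outputs over the original and flipped batch."""
    model.eval()
    probs = torch.softmax(model(images), dim=1)
    # Flip along the width dimension (last axis of an NCHW tensor).
    probs += torch.softmax(model(images.flip(-1)), dim=1)
    probs /= 2
    confidence, pred = probs.max(dim=1)
    return pred, confidence
```

This variant costs two forward passes per batch.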

Test-time augmentation can improve accuracy slightly, but inference cost grows linearly with the number of transforms because each one requires a separate forward pass. Use it only when the latency and compute budget allow.

### Debugging Augmentation

Augmentation bugs are common because transformed images may still have valid tensor shapes while being semantically wrong.

Always visualize transformed examples.

```python
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))

img = images[0]
img = img.permute(1, 2, 0)

plt.imshow(img)
plt.title(str(labels[0].item()))
plt.axis("off")
plt.show()
```

If the image was normalized, unnormalize before plotting:

```python
import torch

def unnormalize(img, mean, std):
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)
    return img * std + mean
```

Then:

```python
img = unnormalize(
    images[0].cpu(),
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

img = img.clamp(0, 1).permute(1, 2, 0)
plt.imshow(img)
```

Visual inspection catches label-breaking crops, excessive color shifts, wrong channel order, and normalization mistakes.

### Practical Recipes

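For a small natural image dataset (in these recipes, `mean` and `std` are the per-channel normalization statistics, such as the ImageNet values used above):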

```python
transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
```

For a medium dataset:

```python
transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.25),
])
```

For a large dataset:

```python
transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.25),
])
```

For fine-grained classification, use conservative crops and color changes. For robust general object classification, use stronger augmentation. For medical or scientific images, design augmentation with domain knowledge.

### Summary

Data augmentation improves generalization by training the model on label-preserving variations of the data. The main principle is semantic validity: an augmentation is useful only when it preserves the class label.

In PyTorch, augmentation is usually implemented through torchvision transforms and, for batch-level methods, inside the training loop. Basic policies use crops, flips, and normalization. Stronger policies add color jitter, erasing, RandAugment, Mixup, and CutMix.

A good augmentation policy is empirical but constrained by domain knowledge. It should improve validation performance, preserve labels, and avoid hiding the signal needed for classification.

