# Variational Autoencoders

A variational autoencoder, or VAE, is an autoencoder with a probabilistic latent space. Instead of mapping an input $x$ to one fixed latent vector $z$, the encoder maps $x$ to a probability distribution over latent vectors.

A standard autoencoder computes

$$
z = f_\theta(x).
$$

A variational autoencoder computes

$$
q_\phi(z \mid x),
$$

where $q_\phi$ is an approximate posterior distribution. The model samples $z$ from this distribution and decodes it back to the data space:

$$
z \sim q_\phi(z \mid x),
$$

$$
\hat{x} = g_\theta(z).
$$

The main purpose of a VAE is generative modeling. After training, we can sample a latent vector from a simple prior distribution and pass it through the decoder to generate new data.

### Latent Variable Models

A latent variable model assumes that each observed example $x$ is generated from an unobserved variable $z$.

The model defines a prior over latent variables:

$$
p(z).
$$

It also defines a likelihood, or decoder distribution:

$$
p_\theta(x \mid z).
$$

Together, these define the marginal probability of an observed input:

$$
p_\theta(x) =
\int p_\theta(x \mid z)p(z)\,dz.
$$

The integral averages the decoder likelihood over all possible latent explanations $z$, weighted by the prior. For deep neural decoders this integral has no closed form and is usually intractable. VAEs sidestep this problem with approximate inference.
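
To see why, consider a naive Monte Carlo estimate that samples latent vectors from the prior and averages the decoder likelihood. The sketch below assumes a Bernoulli decoder over flattened pixels; `decoder` and `naive_log_marginal` are hypothetical names used only for illustration. In a high-dimensional latent space, almost all prior samples explain a given $x$ poorly, so this estimator needs an impractically large number of samples to be accurate.

```python
import torch

def naive_log_marginal(decoder, x, latent_dim, num_samples=1000):
    # Monte Carlo estimate of log p(x) = log E_{z ~ p(z)}[ p(x | z) ].
    z = torch.randn(num_samples, latent_dim)    # z_k ~ N(0, I)
    probs = decoder(z)                          # Bernoulli means, [num_samples, input_dim]

    # log p(x | z_k) for each sample; x is a single flattened input.
    log_px_given_z = (
        x * torch.log(probs + 1e-8)
        + (1 - x) * torch.log(1 - probs + 1e-8)
    ).sum(dim=1)                                # [num_samples]

    # log-mean-exp for numerical stability.
    return torch.logsumexp(log_px_given_z, dim=0) - torch.log(
        torch.tensor(float(num_samples))
    )
```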

### Encoder as Approximate Inference

The true posterior distribution is

$$
p_\theta(z \mid x) =
\frac{p_\theta(x \mid z)p(z)}{p_\theta(x)}.
$$

This distribution tells us which latent variables could have generated $x$. But computing it requires $p_\theta(x)$, which contains the intractable integral above.

A VAE introduces an encoder distribution

$$
q_\phi(z \mid x)
$$

to approximate the true posterior:

$$
q_\phi(z \mid x) \approx p_\theta(z \mid x).
$$

For many VAEs, this approximate posterior is chosen to be a diagonal Gaussian:

$$
q_\phi(z \mid x) =
\mathcal{N}
\left(
z;
\mu_\phi(x),
\operatorname{diag}(\sigma_\phi^2(x))
\right).
$$

The encoder network outputs two vectors:

$$
\mu_\phi(x)
\quad \text{and} \quad
\log \sigma_\phi^2(x).
$$

These outputs define the mean and variance of the approximate posterior. Predicting the log-variance keeps the variance positive without constraining the network's output range.
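
As a small illustration, the diagonal Gaussian can be represented with `torch.distributions.Normal`; the shapes below are placeholders. Calling `rsample` draws a reparameterized sample, which is discussed further in a later section.

```python
import torch
from torch.distributions import Normal

# Placeholder encoder outputs for a batch of 4 inputs and 32 latent dimensions.
mu = torch.zeros(4, 32)
logvar = torch.zeros(4, 32)

# Diagonal Gaussian: an independent Normal over each latent dimension.
q = Normal(mu, torch.exp(0.5 * logvar))

z = q.rsample()                      # differentiable sample, shape [4, 32]
log_q = q.log_prob(z).sum(dim=-1)    # log q(z | x), summed over dimensions
```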

### The Prior

The most common VAE prior is a standard normal distribution:

$$
p(z) = \mathcal{N}(0,I).
$$

This prior gives the latent space a simple global structure. During generation, we sample

$$
z \sim \mathcal{N}(0,I)
$$

and decode it:

$$
x \sim p_\theta(x \mid z).
$$

The prior matters because it defines where valid latent codes should live. Without a prior, the latent space of an ordinary autoencoder may contain holes. Points between encoded examples may decode poorly. A VAE regularizes the latent space so that samples from the prior decode into plausible data.

### Evidence Lower Bound

The training objective of a VAE is the evidence lower bound, usually called the ELBO.

The log likelihood of a data point is

$$
\log p_\theta(x).
$$

Directly maximizing this quantity is difficult. Instead, the VAE maximizes a lower bound:

$$
\log p_\theta(x)
\ge
\mathbb{E}_{q_\phi(z\mid x)}
[
\log p_\theta(x\mid z)
] -
\mathrm{KL}
(
q_\phi(z\mid x)
\| p(z)
).
$$

The first term is the reconstruction term. It rewards latent samples that allow the decoder to explain $x$.

The second term is the KL regularization term. It keeps the encoder distribution close to the prior.
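
The bound itself follows from Jensen's inequality. Writing the marginal likelihood as an expectation under $q_\phi$ and moving the logarithm inside the expectation gives

$$
\log p_\theta(x)
= \log
\mathbb{E}_{q_\phi(z\mid x)}
\left[
\frac{p_\theta(x\mid z)\,p(z)}{q_\phi(z\mid x)}
\right]
\ge
\mathbb{E}_{q_\phi(z\mid x)}
\left[
\log \frac{p_\theta(x\mid z)\,p(z)}{q_\phi(z\mid x)}
\right],
$$

and the right-hand side rearranges into exactly these two terms.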

Thus the VAE objective has two forces:

| Term | Effect |
|---|---|
| Reconstruction term | Preserve information about $x$ |
| KL term | Regularize latent space toward $p(z)$ |

Training minimizes the negative ELBO:

$$
L = -
\mathbb{E}_{q_\phi(z\mid x)}
[
\log p_\theta(x\mid z)
]
+
\mathrm{KL}
(
q_\phi(z\mid x)
\| p(z)
).
$$

### Closed-Form KL for Gaussian Latents

For the common case

$$
q_\phi(z\mid x) =
\mathcal{N}
(\mu, \operatorname{diag}(\sigma^2))
$$

and

$$
p(z)=\mathcal{N}(0,I),
$$

the KL divergence has a closed form:

$$
\mathrm{KL}(q_\phi(z\mid x)\|p(z)) =
\frac{1}{2}
\sum_{j=1}^d
\left(
\mu_j^2
+
\sigma_j^2 -
\log \sigma_j^2 -
1
\right).
$$

This term is easy to compute in PyTorch when the encoder outputs `mu` and `logvar`:

```python
import torch

def kl_divergence(mu, logvar):
    # KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions.
    return 0.5 * torch.sum(
        mu.pow(2) + logvar.exp() - logvar - 1,
        dim=1,
    )
```

The result has shape `[B]`, one KL value per example.

### Reparameterization Trick

Sampling from $q_\phi(z\mid x)$ creates a problem for gradient-based learning: a sample drawn directly from the distribution is not a differentiable function of $\mu_\phi(x)$ and $\sigma_\phi(x)$, so gradients from the reconstruction loss cannot flow back into the encoder parameters.

The reparameterization trick solves this by rewriting sampling as a deterministic transformation of noise:

$$
z = \mu + \sigma \odot \epsilon,
$$

where

$$
\epsilon \sim \mathcal{N}(0,I).
$$

Now randomness comes from $\epsilon$, while $\mu$ and $\sigma$ remain differentiable outputs of the encoder.

In PyTorch:

```python
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)   # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)     # epsilon ~ N(0, I)
    return mu + std * eps           # z = mu + sigma * epsilon
```

This is one of the central implementation details of VAEs.
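
A quick check with placeholder shapes confirms that gradients reach the encoder outputs through the sampled $z$:

```python
mu = torch.zeros(4, 8, requires_grad=True)
logvar = torch.zeros(4, 8, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()

print(mu.grad is not None)       # True
print(logvar.grad is not None)   # True
```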

### Minimal VAE in PyTorch

A simple fully connected VAE for flattened images can be written as follows:

```python
import torch
from torch import nn

class VAE(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )

        # Two linear heads: the mean and log-variance of q(z | x).
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.mu(h), self.logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar, z
```

For MNIST-like data, the input dimension is $784$, and the output uses `Sigmoid` because pixels are scaled to $[0,1]$.

### VAE Loss in PyTorch

The VAE loss combines reconstruction loss and KL divergence.

```python
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar):
    # Reconstruction term: negative Bernoulli log-likelihood,
    # summed over pixels and over the batch.
    reconstruction = F.binary_cross_entropy(
        x_hat,
        x,
        reduction="sum",
    )

    # KL term: closed-form KL to the standard normal prior,
    # summed over latent dimensions and over the batch.
    kl = 0.5 * torch.sum(
        mu.pow(2) + logvar.exp() - logvar - 1
    )

    return reconstruction + kl
```

A full training step:

```python
model = VAE(input_dim=784, latent_dim=32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Stand-in batch: 128 flattened 28x28 images with values in [0, 1].
x = torch.rand(128, 784)

x_hat, mu, logvar, z = model(x)
loss = vae_loss(x_hat, x, mu, logvar)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Many implementations average the loss over the batch instead of summing it. Both conventions are valid, but the learning rate and KL coefficient must be interpreted consistently.
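
For reference, a batch-mean variant might look like the sketch below; `vae_loss_mean` is an illustrative name, not part of the code above. The sums over pixels and latent dimensions are kept per example so that the reconstruction and KL terms stay on comparable scales before averaging.

```python
def vae_loss_mean(x_hat, x, mu, logvar):
    # Per-example sums over pixels and latent dimensions, then a batch mean.
    reconstruction = F.binary_cross_entropy(
        x_hat,
        x,
        reduction="none",
    ).sum(dim=1)                                  # [B]

    kl = 0.5 * torch.sum(
        mu.pow(2) + logvar.exp() - logvar - 1,
        dim=1,
    )                                             # [B]

    return (reconstruction + kl).mean()
```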

### Beta-VAE

A beta-VAE modifies the VAE loss by weighting the KL term:

$$
L = -
\mathbb{E}_{q_\phi(z\mid x)}
[
\log p_\theta(x\mid z)
]
+
\beta
\mathrm{KL}
(
q_\phi(z\mid x)
\|p(z)
).
$$

When $\beta = 1$, this is the standard VAE objective.

When $\beta > 1$, the model is pushed harder toward the prior. This may encourage disentangled latent factors, but it often reduces reconstruction quality.

When $\beta < 1$, the model can preserve more information in $z$, but the latent space may become less regular.

In PyTorch:

```python
def beta_vae_loss(x_hat, x, mu, logvar, beta: float = 4.0):
    reconstruction = F.binary_cross_entropy(
        x_hat,
        x,
        reduction="sum",
    )

    kl = 0.5 * torch.sum(
        mu.pow(2) + logvar.exp() - logvar - 1
    )

    return reconstruction + beta * kl
```

### Generating New Samples

After training, generation no longer requires the encoder. We sample from the prior and decode:

```python
model.eval()

with torch.no_grad():
    z = torch.randn(16, 32)
    samples = model.decode(z)
```

If the model was trained on $28 \times 28$ images, the generated samples can be reshaped:

```python
images = samples.reshape(16, 1, 28, 28)
```

This works because the decoder has learned a mapping from latent space to data space.

### Latent Interpolation

A useful way to inspect a VAE is to interpolate between two latent codes.

Given two inputs $x_a$ and $x_b$, encode them into means:

$$
\mu_a = \mu_\phi(x_a),
\quad
\mu_b = \mu_\phi(x_b).
$$

Then create intermediate latent vectors:

$$
z_\alpha = (1-\alpha)\mu_a + \alpha \mu_b,
\quad
0 \le \alpha \le 1.
$$

Decode each $z_\alpha$. A well-structured latent space should produce smooth transitions.

```python
def interpolate(model, x_a, x_b, steps: int = 8):
    model.eval()

    with torch.no_grad():
        mu_a, _ = model.encode(x_a)
        mu_b, _ = model.encode(x_b)

        alphas = torch.linspace(0, 1, steps).to(x_a.device)
        zs = [(1 - a) * mu_a + a * mu_b for a in alphas]
        zs = torch.cat(zs, dim=0)

        return model.decode(zs)
```

Interpolation is qualitative, but it often reveals whether the latent space has learned useful continuity.

### Posterior Collapse

A common VAE failure mode is posterior collapse. This happens when the decoder ignores the latent variable.

In posterior collapse,

$$
q_\phi(z\mid x) \approx p(z).
$$

The KL term becomes small, but the latent code carries little information about the input.

This occurs most often when the decoder is very powerful. For example, an autoregressive text decoder may model the data well without using $z$.

Common mitigation strategies include:

| Method | Purpose |
|---|---|
| KL annealing | Slowly increase KL weight during training |
| Free bits | Require each latent dimension to carry minimum information |
| Weaker decoder | Force dependence on $z$ |
| Skip connections with care | Avoid bypassing the latent code |
| Better posterior families | Improve approximate inference |

Posterior collapse is one of the main practical difficulties in training VAEs for language.
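
As a concrete illustration of the first two strategies in the table above, the sketch below shows a linear KL annealing schedule and a simplified per-dimension free-bits penalty. The warm-up length and threshold are illustrative, and free bits is often applied to the minibatch-averaged KL rather than per example.

```python
def kl_weight(step: int, warmup_steps: int = 10_000) -> float:
    # Linear KL annealing: ramp the KL weight from 0 to 1 over warmup_steps.
    return min(1.0, step / warmup_steps)

def free_bits_kl(mu, logvar, free_bits: float = 0.5):
    # Per-dimension KL with a floor: dimensions already below the threshold
    # contribute a constant, so the optimizer stops squeezing them further.
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)    # [B, d]
    return torch.clamp(kl_per_dim, min=free_bits).sum(dim=1)      # [B]
```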

### Reconstruction Sharpness

VAEs often produce blurrier images than GANs or diffusion models. One reason is the likelihood model.

If the decoder is trained with mean squared error, it tends to predict conditional averages. When several outputs are plausible, the average may be blurry.
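
This is a property of the squared-error objective itself: for a fixed $z$, the prediction that minimizes the expected squared error is the conditional mean,

$$
\hat{x}^{*}(z) =
\arg\min_{\hat{x}}
\mathbb{E}\left[\|x - \hat{x}\|^2 \mid z\right]
= \mathbb{E}[x \mid z],
$$

and averaging several distinct plausible outputs washes out fine detail.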

For binary cross-entropy, outputs may also appear smooth because the decoder predicts pixel probabilities rather than deterministic pixels.

More expressive decoders, perceptual losses, hierarchical latents, vector quantization, and diffusion decoders can improve sample quality.

### Disentanglement

A disentangled representation separates independent factors of variation into different latent dimensions.

For example, one latent dimension may control rotation, another may control stroke thickness, and another may control style.

Beta-VAE and related methods try to encourage disentanglement by increasing pressure toward the prior. Stronger regularization can make latent dimensions more independent.

However, disentanglement is difficult. Without supervision or inductive bias, there is no guarantee that a model will discover the factors humans expect. Many equivalent latent coordinate systems can explain the same data.

### Summary

A variational autoencoder is a probabilistic autoencoder. The encoder defines an approximate posterior $q_\phi(z\mid x)$, the decoder defines a likelihood $p_\theta(x\mid z)$, and the prior $p(z)$ gives structure to the latent space.

The VAE objective maximizes the ELBO, which balances reconstruction quality against KL regularization. The reparameterization trick allows gradients to pass through stochastic latent sampling.

VAEs provide a principled bridge between autoencoders and probabilistic generative models. They are useful for representation learning, latent interpolation, uncertainty-aware generation, and as building blocks for more advanced generative systems.

