Conv autoencoder (MNIST-scale)
import torch
import torch.nn as nn

class ConvAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, latent_dim),
        )
        self.dec_fc = nn.Linear(latent_dim, 128 * 4 * 4)
        self.dec = nn.Sequential(
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(True),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1), nn.Sigmoid(),
        )

    def encode(self, x):
        return self.enc(x)

    def decode(self, z):
        h = self.dec_fc(z)
        return self.dec(h)

    def forward(self, x):
        return self.decode(self.encode(x))
Assumes input [B,1,32,32], so three stride-2 stages land at 4×4 (32→16→8→4); pad MNIST's native 28×28 to 32×32, or adjust channels/sizes for your resolution.
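The 128 * 4 * 4 flatten size can be verified by running the encoder stack alone on a random batch; a standalone sketch (same conv layers as the encoder above):

```python
import torch
import torch.nn as nn

# Same conv stack as ConvAE's encoder above, minus Flatten/Linear,
# so the pre-flatten feature-map shape is visible.
enc = nn.Sequential(
    nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(True),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(True),
    nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(True),
)
x = torch.randn(2, 1, 32, 32)
h = enc(x)
print(h.shape)  # torch.Size([2, 128, 4, 4]) -> flattens to 128 * 4 * 4
```

Each Conv2d with kernel 3, stride 2, padding 1 halves the spatial size, which is why the Linear layer expects 128 * 4 * 4 features.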
Training step (MSE)
model = ConvAE(latent_dim=32)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(x):
    x_hat = model(x)
    loss = nn.functional.mse_loss(x_hat, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
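A minimal smoke test of this training-step pattern. To keep the snippet self-contained it uses a tiny linear autoencoder as a stand-in for ConvAE and a fixed random batch; the loss should trend downward:

```python
import torch
import torch.nn as nn

# Tiny stand-in autoencoder so the snippet runs on its own
# (swap in ConvAE from above for real use).
model = nn.Sequential(
    nn.Flatten(), nn.Linear(32 * 32, 16),
    nn.Linear(16, 32 * 32), nn.Unflatten(1, (1, 32, 32)),
)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(x):
    x_hat = model(x)
    loss = nn.functional.mse_loss(x_hat, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

torch.manual_seed(0)
x = torch.rand(8, 1, 32, 32)                  # stands in for a data batch
losses = [train_step(x) for _ in range(50)]
print(losses[0], losses[-1])                  # loss decreases on this fixed batch
```

In a real loop, `x` would come from a DataLoader and you would iterate over epochs; the step itself is unchanged.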
Denoising variant
def add_noise(x, sigma=0.2):
    return (x + sigma * torch.randn_like(x)).clamp(0, 1)

def denoise_step(x_clean):
    x_noisy = add_noise(x_clean)
    x_hat = model(x_noisy)
    loss = nn.functional.mse_loss(x_hat, x_clean)  # target is the clean image
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
Latent inspection
model.eval()
with torch.no_grad():
    z = model.encode(batch)   # batch: a [B,1,32,32] input tensor
    recon = model.decode(z)
Use t-SNE/UMAP on z for 2D visualization; anomalies often have high reconstruction error.
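The per-sample reconstruction error used for anomaly scoring is just an MSE reduced over everything but the batch dimension. A standalone sketch (random tensors stand in for `batch` and `model(batch)`; the threshold rule is one common, illustrative choice):

```python
import torch

torch.manual_seed(0)
x = torch.rand(4, 1, 32, 32)            # stands in for the input batch
x_hat = x + 0.1 * torch.randn_like(x)   # stands in for model(batch)

# Squared error per sample: flatten all non-batch dims, then mean
err = ((x - x_hat) ** 2).flatten(1).mean(dim=1)   # shape [4]
print(err.shape)

# One simple anomaly rule: flag samples more than 2 std above the mean error
anomalies = err > err.mean() + 2 * err.std()
```

With a trained model, in-distribution samples cluster at low `err` and anomalies sit in the tail, which is what makes the thresholding useful.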
VAE (contrast)
Encoder outputs μ and log σ²; sample z = μ + σ·ε with ε ~ N(0, I); decoder reconstructs. Loss = reconstruction + KL divergence to a standard normal. This enables sampling new images from z ~ N(0, I); a standard AE does not define a proper generative density without extra assumptions.
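The reparameterization and KL term described above can be sketched as follows. `mu` and `logvar` would come from two heads of a VAE encoder; here they are zero tensors just so the snippet runs standalone (with μ = 0, log σ² = 0, the KL is exactly 0):

```python
import torch

torch.manual_seed(0)
B, latent_dim = 8, 32
mu = torch.zeros(B, latent_dim)      # stand-in for the encoder's mean head
logvar = torch.zeros(B, latent_dim)  # stand-in for the log-variance head

# Reparameterization: z = mu + sigma * eps, eps ~ N(0, I).
# Sampling stays differentiable w.r.t. mu and logvar.
eps = torch.randn_like(mu)
z = mu + torch.exp(0.5 * logvar) * eps

# KL(q(z|x) || N(0, I)) per sample, summed over latent dims
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
print(z.shape, kl.shape)
```

Training minimizes reconstruction loss plus this `kl` (often mean-reduced over the batch); generation then draws z ~ N(0, I) and runs only the decoder.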
Takeaways
- AE: compress and reconstruct; bottleneck controls capacity.
- Denoising: learn invariances to corruption.
- VAE: generative latent with KL regularization.