Generative Adversarial Networks (GANs)
GANs pit two networks against each other: a Generator that creates fake data and a Discriminator that detects fakes. Through this competition, both improve—generating hyper-realistic images, audio, and beyond.
Latent dim: 64-512
Nash equilibrium: the training goal
Mode collapse: the key challenge
StyleGAN: 1024x1024 resolution
The Adversarial Game
GANs are a minimax game between two players: Generator (G) and Discriminator (D). G tries to fool D by generating realistic samples. D tries to distinguish real from fake.
D(x) = probability that x is real. G learns to maximize D(G(z)).
Minimax objective: min_G max_D V(D,G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
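As a numeric sketch (NumPy with made-up discriminator scores, not a real model), the objective can be evaluated directly; it also shows why the non-saturating generator loss -log D(G(z)) used in practice gives a stronger signal early in training than log(1 - D(G(z))):

```python
import numpy as np

# Toy discriminator scores while G is still weak:
d_real = np.array([0.9, 0.8, 0.95])   # D is confident reals are real
d_fake = np.array([0.05, 0.1, 0.08])  # D is confident fakes are fake

# Minimax value V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
V = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))

# Saturating generator loss vs the non-saturating variant
g_saturating = np.mean(np.log(1 - d_fake))   # near 0 when D(G(z)) ~ 0: weak gradient
g_nonsaturating = -np.mean(np.log(d_fake))   # large when D rejects fakes: strong signal

print(V, g_saturating, g_nonsaturating)
```

With these scores the saturating loss is close to zero while the non-saturating loss is large, which is exactly why the latter is preferred when the discriminator dominates early on.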
Vanilla GAN – The Original
Generator
Maps latent vector z to data space. Typically MLP with ReLU + sigmoid/tanh for images.
import tensorflow as tf

def build_generator(latent_dim=100):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, input_dim=latent_dim),
        tf.keras.layers.ReLU(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(784, activation='tanh')  # flattened 28x28 MNIST image
    ])
    return model
Discriminator
Binary classifier. Outputs probability of real image.
def build_discriminator():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, input_shape=(784,)),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model
# Training loop (alternating updates)
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        batch_size = real_imgs.size(0)
        z = torch.randn(batch_size, latent_dim)

        # Train Discriminator
        optimizer_D.zero_grad()
        fake_imgs = G(z)
        real_pred = D(real_imgs)
        fake_pred = D(fake_imgs.detach())  # detach so G receives no gradients here
        d_loss = -torch.mean(torch.log(real_pred) + torch.log(1 - fake_pred))
        d_loss.backward()
        optimizer_D.step()

        # Train Generator (non-saturating loss)
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = G(z)
        fake_pred = D(fake_imgs)
        g_loss = -torch.mean(torch.log(fake_pred))
        g_loss.backward()
        optimizer_G.step()
DCGAN – Convolutional GAN
DCGAN brought CNNs to GANs with key architectural guidelines that stabilized training.
Guidelines
- Replace pooling with strided conv (D) / fractional conv (G)
- BatchNorm in both G and D
- Remove fully connected layers
- ReLU in G (except output tanh)
- LeakyReLU in D
# Generator (DCGAN)
import torch.nn as nn

class DCGenerator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(512), nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),         # 4x4 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),         # 8x8 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),           # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, z):
        return self.deconv(z.view(z.size(0), z.size(1), 1, 1))
Training Challenges & Stabilization
Mode collapse (G covers only a few modes of the data). Solutions: WGAN, minibatch discrimination, unrolled GANs.
Vanishing gradients (an overconfident D leaves G with no learning signal). Solutions: WGAN (Earth Mover distance), label smoothing, instance noise.
Wasserstein GAN (WGAN)
WGAN replaces the binary discriminator with a critic that scores realness. It minimizes the Earth Mover (Wasserstein-1) distance, which gives smoother gradients and more stable training.
WGAN Loss
D_loss = E[D(fake)] - E[D(real)]
G_loss = -E[D(fake)]
Critic weights clipped to [-c, c] (WGAN) or gradient penalty (WGAN-GP).
def gradient_penalty(critic, real, fake, device):
    batch_size, c, h, w = real.shape
    # Random interpolation points between real and fake samples
    epsilon = torch.rand(batch_size, 1, 1, 1).repeat(1, c, h, w).to(device)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)  # needed to differentiate w.r.t. the inputs
    mixed_score = critic(interpolated)
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_score,
        grad_outputs=torch.ones_like(mixed_score),
        create_graph=True,
        retain_graph=True
    )[0]
    gradient = gradient.view(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    gp = torch.mean((gradient_norm - 1) ** 2)
    return gp
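To see what the penalty enforces, here is a NumPy sketch (an illustrative assumption, not part of the code above) using a linear critic f(x) = w·x, whose input gradient is just w everywhere, so the penalty reduces to (||w|| - 1)²:

```python
import numpy as np

# Linear "critic" f(x) = w . x; its gradient w.r.t. x is w at every point,
# so the gradient norm at any interpolated sample is ||w||
w = np.array([3.0, 4.0])           # ||w|| = 5: violates the 1-Lipschitz constraint
grad_norm = np.linalg.norm(w)

# The WGAN-GP penalty pushes the gradient norm toward 1
gp = (grad_norm - 1.0) ** 2
print(gp)  # 16.0

# A critic with ||w|| = 1 pays (almost) no penalty
w_ok = np.array([0.6, 0.8])
penalty_ok = (np.linalg.norm(w_ok) - 1.0) ** 2
```

Weight clipping enforces the same Lipschitz constraint bluntly, by bounding every weight; the penalty is softer and is why WGAN-GP trains more reliably.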
Conditional GAN (cGAN)
Both generator and discriminator receive additional condition (class label, text, image). Enables controlled generation.
Architecture
Concatenate condition y to z (G) and to x (D).
# Generator
z = torch.randn(batch_size, latent_dim)
y = F.one_hot(labels, num_classes).float()  # condition as one-hot vector
gen_input = torch.cat([z, y], dim=1)
fake = G(gen_input)

# Discriminator (for an MLP D on flattened images;
# a conv D would broadcast y as extra input channels instead)
dis_input = torch.cat([image, y], dim=1)
score = D(dis_input)
Applications
- Pix2Pix: Image-to-image translation (edges→photo)
- CycleGAN: Unpaired translation (horse→zebra)
- Text-to-Image: Generate images from descriptions
- SRGAN: Super-resolution
StyleGAN & Progressive GANs
Progressive GAN
Start with low resolution (4x4), add layers as training progresses. Stabilizes high-res generation.
1024x1024 faces, cats, cars.
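A minimal NumPy sketch of the fade-in trick (simplified to 1-D "images" as an illustration): while a new higher-resolution layer is added, its output is alpha-blended with an upsampled copy of the previous resolution, and alpha ramps from 0 to 1 over training.

```python
import numpy as np

def fade_in(low_res, high_res, alpha):
    """Blend the old (upsampled) output with the new layer's output."""
    upsampled = np.repeat(low_res, 2)        # nearest-neighbor 2x upsample
    return alpha * high_res + (1 - alpha) * upsampled

low = np.array([0.0, 1.0])                   # old 2-pixel output
high = np.array([0.1, 0.2, 0.8, 0.9])        # new 4-pixel layer output

start = fade_in(low, high, alpha=0.0)        # pure upsampled old output
mid = fade_in(low, high, alpha=0.5)          # halfway through the transition
end = fade_in(low, high, alpha=1.0)          # new layer fully active
print(start, mid, end)
```

Because the network always produces a valid image during the transition, gradients stay meaningful while the new layer learns, which is what stabilizes high-resolution training.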
StyleGAN
Mapping network + AdaIN (adaptive instance normalization). Style mixing enables controllable synthesis (pose, identity, lighting).
Key idea: Noise injects stochastic variation (freckles, hair).
StyleGAN formula: w = MappingNetwork(z) → AdaIN(conv, w) → stochastic variation via noise. Separates high-level attributes from stochastic details.
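AdaIN itself is short to write down; here is a NumPy sketch (shapes and the style-to-scale mapping are simplifying assumptions) that normalizes each channel and re-scales it with style-derived parameters:

```python
import numpy as np

def adain(x, style_scale, style_shift, eps=1e-5):
    """Adaptive instance norm: per-channel normalize, then apply style.
    x: (C, H, W) feature map; style_scale/style_shift: (C,) derived from w."""
    mu = x.mean(axis=(1, 2), keepdims=True)
    sigma = x.std(axis=(1, 2), keepdims=True)
    normalized = (x - mu) / (sigma + eps)
    return style_scale[:, None, None] * normalized + style_shift[:, None, None]

rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(4, 8, 8))     # feature map with arbitrary stats

# After AdaIN, every channel carries the style's statistics instead
out = adain(x, style_scale=np.full(4, 2.0), style_shift=np.full(4, 5.0))
print(out.mean(axis=(1, 2)), out.std(axis=(1, 2)))
```

This is the mechanism behind style mixing: swapping the w used at different layers swaps the statistics, and hence the attributes, those layers control.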
Evaluating GANs – FID & Inception Score
Inception Score (IS)
Uses ImageNet-pretrained Inception. Measures:
- High confidence predictions (realistic)
- Diversity across samples
Criticism: it misses intra-class mode collapse; a model that emits one convincing image per ImageNet class still scores high.
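The score itself is exp of the mean KL divergence between per-sample predictions p(y|x) and the marginal p(y); a NumPy sketch on made-up class probabilities (no real Inception model involved):

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """probs: (N, num_classes) softmax outputs p(y|x), one row per sample.
    IS = exp(E_x[ KL(p(y|x) || p(y)) ])."""
    p_y = probs.mean(axis=0)  # marginal class distribution p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return np.exp(np.mean(kl))

# Confident AND diverse: each sample nails a different class -> IS = num_classes
diverse = np.eye(4)
# Confident but collapsed: every sample predicts the same class -> IS = 1
collapsed = np.tile(np.array([[1.0, 0.0, 0.0, 0.0]]), (4, 1))

print(inception_score(diverse), inception_score(collapsed))
```

Confidence pushes per-sample KL up; diversity spreads the marginal out. Both are needed for a high score, which is exactly what the two bullet points above describe.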
FID (Fréchet Inception Distance)
Compares statistics of real vs fake in Inception feature space.
FID = ||μ_r - μ_f||² + Tr(Σ_r + Σ_f - 2(Σ_rΣ_f)^½)
Lower is better. FID is the standard metric today.
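For intuition, a NumPy sketch of the formula under a simplifying assumption of diagonal covariances (so the matrix square root is elementwise; real FID uses full covariances of 2048-d Inception features):

```python
import numpy as np

def fid_diagonal(mu_r, var_r, mu_f, var_f):
    """FID for diagonal-covariance Gaussians: the trace term simplifies
    to sum(var_r + var_f - 2*sqrt(var_r * var_f))."""
    mean_term = np.sum((mu_r - mu_f) ** 2)
    trace_term = np.sum(var_r + var_f - 2.0 * np.sqrt(var_r * var_f))
    return mean_term + trace_term

mu_r, var_r = np.array([0.0, 0.0]), np.array([1.0, 1.0])

same = fid_diagonal(mu_r, var_r, mu_r, var_r)                    # identical stats
shifted = fid_diagonal(mu_r, var_r, np.array([1.0, 1.0]), var_r)  # mean gap
print(same, shifted)  # 0.0 2.0
```

Identical statistics give FID 0; any gap in means or covariances increases it, matching "lower is better".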
# FID using torchmetrics (expects uint8 image tensors by default)
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048)
fid.update(real_images, real=True)   # real_images: (N, 3, H, W) uint8
fid.update(fake_images, real=False)
print(f"FID: {fid.compute():.2f}")
Production-Ready GAN Implementations
import pytorch_lightning as pl

class GAN(pl.LightningModule):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_imgs, _ = batch
        z = torch.randn(real_imgs.size(0), self.latent_dim, device=self.device)
        if optimizer_idx == 0:  # train D
            fake_imgs = self.generator(z)
            return self.discriminator_loss(real_imgs, fake_imgs)
        else:  # train G
            fake_imgs = self.generator(z)
            return self.generator_loss(fake_imgs)
# TensorFlow GAN with custom training
@tf.function
def train_step(real_images):
    z = tf.random.normal([BATCH_SIZE, latent_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(z, training=True)
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
GAN Family Comparison
| Model | Key Idea | Stability | Quality | Use Case |
|---|---|---|---|---|
| Vanilla GAN | Minimax BCE | ❌ Low | ⭐ | Educational |
| DCGAN | Convolutional guidelines | ⭐⭐ | ⭐⭐ | Small images |
| WGAN-GP | Wasserstein + gradient penalty | ⭐⭐⭐⭐ | ⭐⭐⭐ | Default stable choice |
| cGAN | Conditional generation | ⭐⭐ | ⭐⭐⭐ | Labeled synthesis |
| StyleGAN | Style modulation + noise | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | High-res faces |
| CycleGAN | Cycle consistency (unpaired) | ⭐⭐⭐ | ⭐⭐⭐ | Unpaired translation |