📖 Lesson ⏱️ 210 minutes

Generative Adversarial Networks

GANs for generating synthetic data and images

Introduction

Generative Adversarial Networks (GANs) are a class of generative models where two neural networks, a generator and a discriminator, compete in a game-theoretic setting to produce realistic synthetic data.

GANs are widely used for:

✅ Image generation and enhancement.
✅ Data augmentation.
✅ Super-resolution tasks.
✅ Art and creative generation.


How GANs Work

1️⃣ Generator (G): Learns to generate realistic-looking data from random noise.
2️⃣ Discriminator (D): Learns to distinguish between real data and fake data generated by G.

The two networks are trained together, with alternating updates:

  • G tries to fool D into classifying generated data as real.
  • D tries to correctly classify real vs fake data.

This adversarial process leads to G improving its outputs to become more realistic over time.
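
The tug-of-war described above can be made concrete with binary cross-entropy, the loss this lesson uses later. The following is a minimal sketch in plain Python, not the lesson's training code: the helper names `bce`, `discriminator_loss`, and `generator_loss` are hypothetical, and the probabilities passed in are made-up values standing in for D's outputs.

```python
import math

def bce(prob, label):
    """Binary cross-entropy for one predicted probability and a 0/1 label."""
    eps = 1e-12  # clamp to avoid log(0)
    prob = min(max(prob, eps), 1 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))

def discriminator_loss(d_real, d_fake):
    # D wants to output 1 on real data and 0 on generated data
    return bce(d_real, 1.0) + bce(d_fake, 0.0)

def generator_loss(d_fake):
    # G wants D to output 1 on generated data
    return bce(d_fake, 1.0)

# Case 1: D is confident and correct (real -> 0.9, fake -> 0.1)
confident = (discriminator_loss(0.9, 0.1), generator_loss(0.1))

# Case 2: D cannot tell the difference and outputs 0.5 everywhere
fooled = (discriminator_loss(0.5, 0.5), generator_loss(0.5))

print("D confident: loss_D=%.4f loss_G=%.4f" % confident)
print("D fooled:    loss_D=%.4f loss_G=%.4f" % fooled)
```

When D is confident, its loss is low and G's loss is high; when D is fooled, the situation reverses. Each network lowers its own loss by raising the other's, which is the sense in which the two compete.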


Challenges with GANs

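  • Training instability: G and D optimize opposing objectives, so the losses can oscillate instead of converging.
  • Mode collapse: The generator may produce only a narrow range of outputs and ignore parts of the real distribution.
  • Hyperparameter sensitivity: Results can change sharply with the learning rates, the network sizes, and the balance between G and D updates.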
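
Mode collapse can be spotted with a rough numeric check before looking at any plots. The sketch below is one simple heuristic, not a standard metric: the helper `spread_ratio` is hypothetical, and the "collapsed" and "healthy" sample lists are hand-made stand-ins for generator output.

```python
import math
import statistics

def spread_ratio(generated, real):
    """Ratio of generated to real standard deviation (population std).

    A value near 0 suggests the generator is emitting almost the same
    output every time; a value near 1 suggests a similar spread.
    """
    return statistics.pstdev(generated) / statistics.pstdev(real)

# Stand-in for real data: sine values on an evenly spaced grid over [0, pi]
real = [math.sin(math.pi * i / 99) for i in range(100)]

# Stand-in for a collapsed generator: every sample is the same value
collapsed = [0.99] * 100

# Stand-in for a healthy generator: same spread as the real data
healthy = list(real)

collapsed_ratio = spread_ratio(collapsed, real)
healthy_ratio = spread_ratio(healthy, real)
print(collapsed_ratio, healthy_ratio)
```

A matching spread does not prove the distributions agree, so this check is best treated as an early warning alongside a histogram comparison.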

Example: Implementing a Simple GAN with PyTorch

We will build a basic GAN to generate 1D synthetic data (sine wave distribution) for clear conceptual understanding.
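
Before training, it helps to know what the target distribution looks like. The sketch below rebuilds the same 100 sine values in plain Python (the training code later does this with `torch.linspace`) and counts them in five equal-width bins; the bin count is an arbitrary choice for illustration.

```python
import math

# Same grid as the lesson: 100 evenly spaced points on [0, 3.14]
xs = [3.14 * i / 99 for i in range(100)]
samples = [math.sin(x) for x in xs]

# Count samples in 5 equal-width bins covering [0, 1]
counts = [0] * 5
for s in samples:
    idx = min(int(s * 5), 4)  # clamp so a value of exactly 1.0 lands in the last bin
    counts[idx] += 1

for i, c in enumerate(counts):
    print(f"[{i / 5:.1f}, {(i + 1) / 5:.1f}): {c}")
```

The values are not spread evenly over [0, 1]: the sine curve is flat near its peak, so most samples pile up close to 1. A well-trained generator should reproduce that skew in the final histogram.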


1️⃣ Install and Import Libraries

pip install torch matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

2️⃣ Define Generator and Discriminator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(10, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

3️⃣ Training the GAN

G = Generator()
D = Discriminator()

criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.001)
optimizer_D = optim.Adam(D.parameters(), lr=0.001)

real_data = torch.sin(torch.linspace(0, 3.14, 100)).unsqueeze(1)

for epoch in range(2000):
    # Train Discriminator
    D.zero_grad()
    real_labels = torch.ones(real_data.size(0), 1)
    output_real = D(real_data)
    loss_real = criterion(output_real, real_labels)

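    # Sample noise and generate a fake batch the same size as the real one
    noise = torch.randn(real_data.size(0), 10)
    fake_data = G(noise)
    fake_labels = torch.zeros(real_data.size(0), 1)
    # detach() keeps the discriminator step from sending gradients into G
    output_fake = D(fake_data.detach())
    loss_fake = criterion(output_fake, fake_labels)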

    loss_D = loss_real + loss_fake
    loss_D.backward()
    optimizer_D.step()

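    # Train Generator: score the same fakes again, this time without detach(),
    # and label them as real so the loss falls when D is fooled
    G.zero_grad()
    output = D(fake_data)
    loss_G = criterion(output, real_labels)
    loss_G.backward()
    optimizer_G.step()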

    if epoch % 500 == 0:
        print(f"Epoch {epoch} - Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

4️⃣ Visualize Generated Data

noise = torch.randn(100, 10)
generated = G(noise).detach().numpy()

plt.hist(generated, bins=20, alpha=0.7, label='Generated')
plt.hist(real_data.numpy(), bins=20, alpha=0.7, label='Real')
plt.legend()
plt.title('GAN Generated Data vs Real Data')
plt.show()

Conclusion

✅ GANs enable the generation of synthetic data that closely resembles real data.
✅ They are a powerful tool in creative and scientific applications.
✅ Understanding the adversarial training process is key to building GANs effectively.


What's Next?

✅ Experiment with DCGANs for image generation.
✅ Learn about Wasserstein GANs for more stable training.
✅ Explore CycleGANs for image-to-image translation.


Join the SuperML Community to share your GAN projects and collaborate with others in advanced deep learning.


Happy Generating! 🌀