PennyLane · Advanced · Free · Part 10 of 26 in series · 70 minutes

Quantum Generative Adversarial Networks with PennyLane

Implement a patch quantum GAN in PennyLane with a parametric quantum generator and classical neural network discriminator. Covers adversarial training, mode collapse, and how entanglement contributes to output diversity.

What you'll learn

  • quantum GAN
  • QGAN
  • PennyLane
  • generative models
  • quantum ML

Prerequisites

  • Strong Python skills
  • Solid quantum computing foundations
  • Linear algebra and complex numbers

Generative Adversarial Networks, Quantumly

Generative adversarial networks (GANs) pit two models against each other: a generator that tries to produce realistic samples and a discriminator that tries to distinguish real from generated samples. As training progresses, the generator learns to fool the discriminator. The ideal end point is a Nash equilibrium of this two-player game: the generator reproduces the target distribution, and the discriminator can do no better than random guessing.
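For reference, the standard GAN objective is the minimax game below, where D(x) is the discriminator's estimated probability that x is real and G(z) maps a noise draw z ~ p_z to a sample. For a fixed generator with sample density p_g, the optimal discriminator is the ratio shown in the second line; it equals 1/2 everywhere once p_g = p_data, which is the "random guessing" equilibrium described above.

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^{*}(x) = \frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_g(x)},
\qquad p_g = p_{\mathrm{data}} \;\Rightarrow\; D^{*}(x) = \tfrac{1}{2}
```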

Quantum GANs (QGANs) replace the generator with a parametric quantum circuit. The motivation is that quantum circuits can represent probability distributions that are hard to sample from classically, potentially enabling generative models with a computational advantage. The Born rule connects circuit amplitudes to probabilities, giving a natural probabilistic output without softmax.
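A minimal NumPy sketch of the Born rule, assuming PennyLane's RY convention RY(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]: prepare a two-qubit state with one RY rotation and a CNOT, then read off the measurement probabilities as squared amplitude magnitudes. They are non-negative and sum to 1 by construction, so no softmax is needed.

```python
import numpy as np

def ry(theta):
    """Single-qubit RY rotation (same matrix convention as PennyLane's qml.RY)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

# CNOT with qubit 0 as control and qubit 1 as target, basis order |q0 q1>
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]])

theta = 1.2
state = np.zeros(4)
state[0] = 1.0                                 # start in |00>
state = np.kron(ry(theta), np.eye(2)) @ state  # RY on qubit 0
state = CNOT @ state                           # cos(theta/2)|00> + sin(theta/2)|11>

probs = np.abs(state) ** 2  # Born rule: p(x) = |<x|psi>|^2
for label, p in zip(["00", "01", "10", "11"], probs):
    print(f"|{label}>: {p:.4f}")
```

Only |00> and |11> get non-zero probability: the two qubits are perfectly correlated, which is exactly the kind of structure a product of independent qubits cannot express.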

This tutorial implements a simplified, one-dimensional variant of the patch QGAN architecture from Huang et al. (2021). The generator is a PennyLane circuit that produces real-valued samples from expectation values, and the discriminator is a classical PyTorch network. You will train the full adversarial loop, diagnose mode collapse, and examine how entanglement affects output diversity.

Architecture Overview

The patch QGAN works as follows:

  • The quantum generator takes a random noise input (encoded as rotation angles) and a set of trainable parameters. It outputs a vector of expectation values in [-1, 1] that represents one “patch” of the generated sample.
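  • Multiple patches, each from its own sub-generator, are concatenated to build a higher-dimensional output. The 1D example in this tutorial uses a single patch and averages its values into one scalar.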
  • The classical discriminator receives either a real sample or a generated sample and outputs a probability of being real.
  • Generator and discriminator are trained alternately with opposing objectives.
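A dependency-free sketch of the patching step. To keep it short, each sub-generator is a hypothetical unentangled circuit with one RY(noise_i) and one trainable RY(w_i) per qubit; because rotations about the same axis add, ⟨Z_i⟩ = cos(noise_i + w_i) exactly. Each patch has its own weights, and the sketch assumes every sub-generator receives the same noise vector. The names toy_sub_generator and generate_patched_sample are illustrative, not part of PennyLane.

```python
import numpy as np

N_QUBITS = 4    # outputs per patch (one <Z> per qubit)
N_PATCHES = 3   # hypothetical number of sub-generators

def toy_sub_generator(noise, weights):
    """Hypothetical stand-in for one quantum sub-generator.

    Each qubit sees RY(noise_i) then RY(w_i). RY rotations about the same axis
    add, so <Z_i> = cos(noise_i + w_i), a value in [-1, 1].
    """
    return np.cos(noise + weights)

def generate_patched_sample(noise, patch_weights):
    """Concatenate the outputs of all sub-generators into one sample."""
    patches = [toy_sub_generator(noise, w) for w in patch_weights]
    return np.concatenate(patches)  # shape: (N_PATCHES * N_QUBITS,)

rng = np.random.default_rng(0)
noise = rng.uniform(-np.pi, np.pi, size=N_QUBITS)
patch_weights = rng.normal(scale=0.1, size=(N_PATCHES, N_QUBITS))  # one row per patch
sample = generate_patched_sample(noise, patch_weights)
print(sample.shape)  # (12,)
```

The output dimension grows linearly with the number of patches while each circuit stays at N_QUBITS wires, which is the point of the patch strategy.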

Setup

import numpy as np
import torch
import torch.nn as nn
import pennylane as qml
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance

torch.manual_seed(0)  # reproducible noise draws and weight initialization

Target Distribution

Use a bimodal Gaussian mixture as the 1D target distribution. This is a classic stress test for GANs because a model that collapses to one mode fails visibly.

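def sample_real_data(n_samples, seed=42):
    """Sample from a bimodal Gaussian mixture: 40% N(-2, 0.5^2) and 60% N(2, 0.8^2)."""
    rng = np.random.default_rng(seed)  # seeded: repeated calls with the same seed return the same samples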
    labels = rng.choice([0, 1], size=n_samples, p=[0.4, 0.6])
    samples = np.where(
        labels == 0,
        rng.normal(loc=-2.0, scale=0.5, size=n_samples),
        rng.normal(loc=2.0, scale=0.8, size=n_samples),
    )
    return samples.astype(np.float32)


# Visualize target
real_samples = sample_real_data(2000)
plt.figure(figsize=(8, 4))
plt.hist(real_samples, bins=60, density=True, alpha=0.7, color='steelblue', label='Real data')
plt.xlabel("Value")
plt.ylabel("Density")
plt.title("Target Distribution (Bimodal Gaussian)")
plt.legend()
plt.tight_layout()
plt.savefig("qgan_target.png", dpi=150)
plt.show()
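Because the target is an explicit mixture, its reference statistics can be computed in closed form rather than estimated. The sketch below (standard library only) computes the mean, standard deviation, and probability mass below zero for p(x) = 0.4·N(-2, 0.5²) + 0.6·N(2, 0.8²). The last number is the roughly 0.40 left-mode mass that the mode-collapse check later in this tutorial expects.

```python
import math

# Mixture components: (weight, mean, std), matching sample_real_data
COMPONENTS = [(0.4, -2.0, 0.5), (0.6, 2.0, 0.8)]

def normal_cdf(x, mu, sigma):
    """CDF of N(mu, sigma^2) via the error function."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

mean = sum(w * mu for w, mu, _ in COMPONENTS)
# Var = E[X^2] - E[X]^2, where for a mixture E[X^2] = sum_k w_k (sigma_k^2 + mu_k^2)
second_moment = sum(w * (s ** 2 + mu ** 2) for w, mu, s in COMPONENTS)
std = math.sqrt(second_moment - mean ** 2)
mass_below_zero = sum(w * normal_cdf(0.0, mu, s) for w, mu, s in COMPONENTS)

print(f"mean={mean:.4f}  std={std:.4f}  P(X<0)={mass_below_zero:.4f}")
```

The mass below zero is slightly above 0.40 because a small tail of the right-hand component crosses zero.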

Quantum Generator

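The generator uses 4 qubits and returns 4 expectation values, one ⟨Z⟩ per qubit. The noise input is encoded with one RY gate per qubit. Each of the 3 layers then applies a trainable RY-RZ-RY rotation to every qubit, followed by a fixed, parameter-free ring of CNOT gates that entangles neighboring qubits.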

n_qubits = 4
n_layers = 3
n_output = n_qubits  # 4-dimensional output per patch

dev = qml.device("default.qubit", wires=n_qubits)


@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def quantum_generator(noise, weights):
    """
    Quantum generator circuit.
    
    Args:
        noise: shape (n_qubits,) -- random angles from noise prior
        weights: shape (n_layers, n_qubits, 3) -- trainable parameters
    
    Returns:
        List of expectation values <Z> for each qubit.
    """
    # Encode noise input
    for i in range(n_qubits):
        qml.RY(noise[i], wires=i)

    # Trainable ansatz: alternating rotation + entanglement layers
    for layer in range(n_layers):
        for i in range(n_qubits):
            qml.RY(weights[layer, i, 0], wires=i)
            qml.RZ(weights[layer, i, 1], wires=i)
            qml.RY(weights[layer, i, 2], wires=i)
        # Entangling CNOT ring
        for i in range(n_qubits):
            qml.CNOT(wires=[i, (i + 1) % n_qubits])

    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]


def generate_samples(weights, batch_size, noise_scale=np.pi, output_scale=4.0):
    """Generate a batch of 1D samples from the quantum generator.

    Each <Z> lies in [-1, 1], so the averaged output is multiplied by
    output_scale (a choice made for this target) to reach the support of
    the bimodal distribution, whose modes sit at -2 and +2.
    """
    # Uniform noise prior on [-noise_scale, noise_scale], one angle per qubit
    noise_batch = torch.empty(batch_size, n_qubits).uniform_(-noise_scale, noise_scale)

    outputs = []
    for noise_vec in noise_batch:
        out = quantum_generator(noise_vec, weights)
        # Stack the expectation values into one patch, then average to a scalar for this 1D problem
        patch = torch.stack(out).float()
        outputs.append(patch.mean().unsqueeze(0) * output_scale)

    return torch.cat(outputs)  # shape: (batch_size,)


# Test generator output shape
init_weights = (torch.randn(n_layers, n_qubits, 3) * 0.1).requires_grad_()  # scale first so the weights stay a trainable leaf tensor
test_output = generate_samples(init_weights, batch_size=5)
print(f"Generator output shape: {test_output.shape}")
print(f"Sample values: {test_output.detach().numpy().flatten()}")
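The QNode above is differentiated with diff_method="parameter-shift". A NumPy sketch of the rule for the simplest case, assuming a single qubit prepared by RY(θ) and measured in Z, so that f(θ) = ⟨Z⟩ = cos θ in closed form: two evaluations shifted by ±π/2 give the exact derivative, not a finite-difference approximation. Because each trainable gate parameter needs its own pair of shifted circuits, the 36-parameter generator needs roughly 72 extra circuit executions per sample for every gradient step, which is why full training here is slow.

```python
import numpy as np

def expval_z(theta):
    """<Z> after RY(theta)|0>, which equals cos(theta)."""
    return np.cos(theta)

def parameter_shift_grad(f, theta, shift=np.pi / 2):
    """Two-term parameter-shift rule for gates of the form exp(-i theta P / 2), P a Pauli."""
    return (f(theta + shift) - f(theta - shift)) / 2

thetas = np.linspace(-np.pi, np.pi, 9)
ps_grad = parameter_shift_grad(expval_z, thetas)
exact_grad = -np.sin(thetas)
print(np.max(np.abs(ps_grad - exact_grad)))  # agrees to floating-point precision
```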

Classical Discriminator

The discriminator is a shallow fully-connected network. It must be expressive enough to distinguish real from fake but not so deep that it dominates the training signal.

class Discriminator(nn.Module):
    def __init__(self, input_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


discriminator = Discriminator(input_dim=1)
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters())}")

Adversarial Training Loop

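Training alternates between updating the discriminator (to better separate real from generated samples) and updating the generator (to better fool the discriminator). Both use binary cross-entropy (BCE). For the generator, the loop uses the common non-saturating form: it scores its samples against the label 1, i.e. it minimizes -log D(G(z)) rather than minimizing log(1 - D(G(z))), which gives stronger gradients early in training when the discriminator rejects generated samples easily.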

def train_qgan(
    n_epochs=200,
    batch_size=64,
    lr_disc=2e-3,
    lr_gen=1e-2,
    n_disc_steps=2,
    log_interval=20,
):
    """Full QGAN training loop."""
    # Initialize
    weights = (torch.randn(n_layers, n_qubits, 3) * 0.1).detach().requires_grad_(True)
    disc = Discriminator(input_dim=1)
    
    gen_optimizer = torch.optim.Adam([weights], lr=lr_gen)
    disc_optimizer = torch.optim.Adam(disc.parameters(), lr=lr_disc)
    criterion = nn.BCELoss()
    
    history = {
        'disc_loss': [],
        'gen_loss': [],
        'wasserstein': [],
        'epoch': [],
    }
    
    real_pool = torch.FloatTensor(sample_real_data(10000)).unsqueeze(1)
    
    for epoch in range(n_epochs):
        # --- Discriminator update ---
        for _ in range(n_disc_steps):
            disc_optimizer.zero_grad()
            
            # Sample real batch
            idx = torch.randint(0, len(real_pool), (batch_size,))
            real_batch = real_pool[idx]
            
            # Sample fake batch
            fake_batch = generate_samples(weights, batch_size).detach().unsqueeze(1)
            
            real_labels = torch.full((batch_size, 1), 0.9)  # one-sided label smoothing: real labels only
            fake_labels = torch.zeros(batch_size, 1)        # fake labels stay at 0
            
            disc_real = disc(real_batch)
            disc_fake = disc(fake_batch)
            
            loss_real = criterion(disc_real, real_labels)
            loss_fake = criterion(disc_fake, fake_labels)
            disc_loss = (loss_real + loss_fake) / 2
            
            disc_loss.backward()
            disc_optimizer.step()
        
        # --- Generator update ---
        gen_optimizer.zero_grad()
        
        fake_batch = generate_samples(weights, batch_size).unsqueeze(1)
        # Generator wants discriminator to output 1 for its samples
        gen_target = torch.ones(batch_size, 1)
        disc_on_fake = disc(fake_batch)
        gen_loss = criterion(disc_on_fake, gen_target)
        
        gen_loss.backward()
        gen_optimizer.step()
        
        # Logging
        if epoch % log_interval == 0 or epoch == n_epochs - 1:
            with torch.no_grad():
                # A 10-sample estimate is very noisy; compare 100 generated samples to the full real pool
                eval_fake = generate_samples(weights, 100).numpy()
                eval_real = real_pool[:, 0].numpy()
                w_dist = wasserstein_distance(eval_real, eval_fake)
            
            history['disc_loss'].append(disc_loss.item())
            history['gen_loss'].append(gen_loss.item())
            history['wasserstein'].append(w_dist)
            history['epoch'].append(epoch)
            
            print(f"Epoch {epoch:4d} | D_loss: {disc_loss.item():.4f} | "
                  f"G_loss: {gen_loss.item():.4f} | W_dist: {w_dist:.4f}")
    
    return weights, disc, history


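# Quick smoke test: 5 tiny epochs only check that the loop runs; they do not train a useful
# generator. For a real run use the defaults (n_epochs=200, batch_size=64, n_disc_steps=2),
# which is slow because every sample needs its own parameter-shift gradient.
print("Starting QGAN training...")
final_weights, final_disc, history = train_qgan(n_epochs=5, batch_size=4, n_disc_steps=1, log_interval=1)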
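The generator update in train_qgan trains against the label 1, which is the non-saturating loss -log D(G(z)) rather than the minimax term log(1 - D(G(z))). A small NumPy check of why that matters: writing D = sigmoid(a) for the discriminator's logit a on a generated sample, the gradient of log(1 - D) with respect to a is -D, which vanishes as D approaches 0 (samples confidently rejected), while the gradient of -log D is -(1 - D), which stays close to -1 in the same regime.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

logits = np.array([-6.0, -3.0, 0.0, 3.0])  # discriminator logits on generated samples
d = sigmoid(logits)                         # D(G(z))

# d/da of log(1 - sigmoid(a)): minimax generator objective (minimized)
grad_minimax = -d
# d/da of -log(sigmoid(a)): non-saturating generator loss (minimized)
grad_nonsat = -(1.0 - d)

for a, p, g1, g2 in zip(logits, d, grad_minimax, grad_nonsat):
    print(f"logit={a:5.1f}  D={p:.4f}  minimax grad={g1:.4f}  non-saturating grad={g2:.4f}")
```

Early in training, when the discriminator easily spots fakes (large negative logits), the minimax gradient is close to zero and the generator barely moves; the non-saturating loss keeps a useful signal.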

Evaluating the Trained Generator

def evaluate_generator(weights, real_samples, history, n_eval=2000):
    """Compare generated and real distributions and plot the training curves."""
    with torch.no_grad():
        generated = generate_samples(weights, n_eval).numpy()
    real_eval = real_samples[:n_eval]

    w_dist = wasserstein_distance(real_eval, generated)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Distribution comparison
    axes[0].hist(real_eval, bins=60, density=True, alpha=0.6,
                 color='steelblue', label='Real')
    axes[0].hist(generated, bins=60, density=True, alpha=0.6,
                 color='coral', label='Generated')
    axes[0].set_xlabel("Value")
    axes[0].set_ylabel("Density")
    axes[0].set_title(f"Distribution Comparison (W={w_dist:.4f})")
    axes[0].legend()

    # Training curves recorded by train_qgan
    axes[1].plot(history['epoch'], history['disc_loss'], label='Discriminator Loss')
    axes[1].plot(history['epoch'], history['gen_loss'], label='Generator Loss')
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Loss")
    axes[1].set_title("Training Curves")
    axes[1].legend()

    plt.tight_layout()
    plt.savefig("qgan_results.png", dpi=150)
    plt.show()

    print(f"Final Wasserstein distance: {w_dist:.6f}")
    print(f"Generated mean: {generated.mean():.4f} (real: {real_eval.mean():.4f})")
    print(f"Generated std:  {generated.std():.4f} (real: {real_eval.std():.4f})")
    return generated


# Small n_eval keeps the smoke test fast; use a few hundred samples or more for a meaningful comparison
generated = evaluate_generator(final_weights, real_samples, history, n_eval=50)

Mode Collapse and How to Detect It

Mode collapse is one of the most common GAN failure modes: the generator concentrates on one of the target modes and ignores the others. In this bimodal case, a collapsed generator produces samples near -2.0 or near 2.0, but not both.

Watch for these signals:

  • The generator loss drops quickly while the discriminator loss rises: the discriminator is being fooled, but only on a narrow slice of the target.
  • The Wasserstein distance plateaus at a high value even though the GAN losses look low.
  • The output histogram shows a single narrow peak.

def diagnose_mode_collapse(generated, real_samples, threshold=0.10):
    """Check for mode collapse in the bimodal target by comparing the mass on each side of 0."""
    left_mass = (generated < 0).mean()
    right_mass = (generated >= 0).mean()
    real_left = (real_samples < 0).mean()  # about 0.40 for this mixture

    print("\n=== Mode Collapse Diagnosis ===")
    print(f"Generated mass below 0: {left_mass:.3f} (real: {real_left:.3f})")
    print(f"Generated mass above 0: {right_mass:.3f} (real: {1 - real_left:.3f})")

    if min(left_mass, right_mass) < threshold:
        print("WARNING: Possible mode collapse detected.")
        print("Try: lower generator learning rate, more discriminator steps, or noise added to discriminator inputs.")
    else:
        print("Both modes are populated; no mode collapse detected at this threshold.")


diagnose_mode_collapse(generated, real_samples)
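To see why the Wasserstein distance is a useful collapse signal, the sketch below compares two synthetic "generators" against the target: one that samples the full mixture, and a hypothetical collapsed one that only ever samples the right-hand mode. For two equal-size 1D samples, the 1-Wasserstein distance reduces to the mean absolute difference between the sorted samples, which is what scipy.stats.wasserstein_distance computes in that case; the sketch uses the NumPy form to stay self-contained. The helper names are illustrative.

```python
import numpy as np

def sample_mixture(rng, n, left_weight=0.4):
    """Draw n samples from left_weight*N(-2, 0.5^2) + (1 - left_weight)*N(2, 0.8^2)."""
    is_left = rng.random(n) < left_weight
    return np.where(is_left, rng.normal(-2.0, 0.5, n), rng.normal(2.0, 0.8, n))

def wasserstein_1d(a, b):
    """1-Wasserstein distance between two equal-size 1D samples."""
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

rng = np.random.default_rng(0)
n = 5000
target = sample_mixture(rng, n)
healthy = sample_mixture(rng, n)                     # covers both modes
collapsed = sample_mixture(rng, n, left_weight=0.0)  # right mode only

for name, gen in [("healthy", healthy), ("collapsed", collapsed)]:
    print(f"{name:9s}  W1={wasserstein_1d(target, gen):.3f}  "
          f"mass below 0={np.mean(gen < 0):.3f}")
```

The collapsed generator's distance cannot drop below the gap between the means (|0.4 - 2.0| = 1.6), no matter how well it fits the single mode it covers.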

The Role of Entanglement in Output Diversity

The CNOT ring in the generator circuit entangles the qubits. Concretely, after the ring each ⟨Z_i⟩ depends on the noise angles and weights of other qubits, so the outputs become correlated, and the joint Born-rule distribution over all qubits can express correlations that a product of independent qubits cannot.

Without entanglement (no CNOT ring), the circuit can only prepare product states: each ⟨Z_i⟩ depends only on noise[i] and that qubit's own weights, so with independent noise angles the 4 outputs are statistically independent. With the ring, the reachable states include entangled ones. For this 1D problem, where the 4 values are averaged into one scalar, the benefit is likely modest. For higher-dimensional generation (images, molecules), entanglement gives the circuit a way to model correlations between output dimensions.

You can test this by removing the entanglement layer:

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def quantum_generator_no_entanglement(noise, weights):
    """Generator without CNOT entanglement -- for comparison."""
    for i in range(n_qubits):
        qml.RY(noise[i], wires=i)
    for layer in range(n_layers):
        for i in range(n_qubits):
            qml.RY(weights[layer, i, 0], wires=i)
            qml.RZ(weights[layer, i, 1], wires=i)
            qml.RY(weights[layer, i, 2], wires=i)
        # No CNOT layer
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
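The dependency structure can be checked directly. The sketch below is a small, hypothetical NumPy state-vector re-implementation of the two circuits above (same RY-RZ-RY layers and CNOT ring, assuming PennyLane's RY and RZ matrix conventions); none of its helpers are PennyLane API. It changes only noise[1] and reports how much ⟨Z_0⟩ moves: without the CNOT ring it does not move at all, because qubit 0 never interacts with qubit 1, while with the ring it generally does.

```python
import numpy as np

def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def rz(t):
    return np.diag([np.exp(-0.5j * t), np.exp(0.5j * t)])

def apply_1q(state, gate, wire):
    """Apply a 2x2 gate to one axis of a (2,)*n state tensor."""
    state = np.tensordot(gate, state, axes=([1], [wire]))
    return np.moveaxis(state, 0, wire)

def apply_cnot(state, control, target):
    """Flip the target axis on the slice where the control qubit is |1>."""
    state = state.copy()
    idx = [slice(None)] * state.ndim
    idx[control] = 1
    sub = state[tuple(idx)]  # control axis removed from this view
    t_axis = target if target < control else target - 1
    state[tuple(idx)] = np.flip(sub, axis=t_axis).copy()
    return state

def expval_z(state, wire):
    probs = np.abs(state) ** 2
    marginal = probs.sum(axis=tuple(a for a in range(state.ndim) if a != wire))
    return marginal[0] - marginal[1]

def generator_expvals(noise, weights, entangle=True):
    """<Z_i> for each qubit of the tutorial's ansatz, simulated with NumPy."""
    n = len(noise)
    state = np.zeros((2,) * n, dtype=complex)
    state[(0,) * n] = 1.0  # |0...0>
    for i in range(n):
        state = apply_1q(state, ry(noise[i]), i)  # noise encoding
    for layer in range(weights.shape[0]):
        for i in range(n):
            state = apply_1q(state, ry(weights[layer, i, 0]), i)
            state = apply_1q(state, rz(weights[layer, i, 1]), i)
            state = apply_1q(state, ry(weights[layer, i, 2]), i)
        if entangle:
            for i in range(n):
                state = apply_cnot(state, i, (i + 1) % n)  # CNOT ring
    return np.array([expval_z(state, i) for i in range(n)])

rng = np.random.default_rng(1)
weights = rng.normal(scale=1.0, size=(3, 4, 3))
noise_a = rng.uniform(-np.pi, np.pi, size=4)
noise_b = noise_a.copy()
noise_b[1] += 1.0  # perturb only qubit 1's noise angle

for entangle in (False, True):
    shift = abs(generator_expvals(noise_b, weights, entangle)[0]
                - generator_expvals(noise_a, weights, entangle)[0])
    print(f"entangle={entangle}:  |change in <Z_0>| = {shift:.2e}")
```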

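The expectation is that the entangled generator trains faster and reaches a lower Wasserstein distance on multimodal targets, because correlations between qubits let the circuit cover more of the target support. This is not guaranteed for every target or initialization; to check it, train both variants with the same settings and compare their Wasserstein curves.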

Comparison to Classical GANs

Classical GANs for 1D data use neural network generators with perhaps 100 to 1,000 parameters. The quantum generator here has n_layers * n_qubits * 3 = 3 * 4 * 3 = 36 trainable parameters, far fewer. This is both a strength (a small, shallow circuit is cheap to optimize and less exposed to barren plateaus) and a limitation (it is less expressive for complex, high-dimensional distributions).
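The parameter counts can be checked by hand. A dense layer with n_in inputs and n_out outputs has n_in * n_out weights plus n_out biases; the sketch below applies that rule to the tutorial's Discriminator (1 → 32 → 16 → 1) and compares it to the generator's rotation angles. The helper dense_params is illustrative.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected (nn.Linear-style) layer."""
    return n_in * n_out + n_out

n_qubits, n_layers, angles_per_qubit_layer = 4, 3, 3
generator_params = n_layers * n_qubits * angles_per_qubit_layer  # one RY-RZ-RY triple per qubit per layer

layer_sizes = [1, 32, 16, 1]  # Discriminator: Linear(1,32) -> Linear(32,16) -> Linear(16,1)
discriminator_params = sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))

print(f"quantum generator:       {generator_params}")      # 36
print(f"classical discriminator: {discriminator_params}")  # 609
```

The 609 matches what the Discriminator cell prints, so even this shallow classical discriminator has more than 16 times as many parameters as the quantum generator it trains against.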

Today, QGANs look most promising for distributions with a natural quantum structure, such as quantum state tomography, Hamiltonian ground-state learning, and generating training data for quantum algorithms. For classical image or text generation, classical GANs remain far ahead.

The barren plateau problem deserves mention. For deep, randomly initialized quantum circuits, the variance of the cost gradient can vanish exponentially with the number of qubits, so training stalls. The patch strategy helps by keeping each sub-circuit small (4 qubits here) and composing patches instead of growing one large circuit. For the same reason, this tutorial keeps the circuit shallow, with n_layers=3 rather than 10 or 20.
