
Compiling Quantum Programs with PennyLane Catalyst

Use PennyLane Catalyst to JIT-compile hybrid quantum-classical programs with JAX, enabling fast execution on both simulators and quantum hardware.

What you'll learn

  • PennyLane
  • Catalyst
  • JIT compilation
  • JAX
  • hybrid quantum-classical

Prerequisites

  • Strong Python skills
  • Solid quantum computing foundations
  • Linear algebra and complex numbers

Overview

Every time a standard PennyLane circuit executes, Python re-dispatches each gate through the PennyLane pipeline individually. For a single forward pass, that overhead is negligible. For a VQE training loop that runs 2,000 gradient steps on a 10-qubit ansatz, the per-gate Python dispatch can add up to minutes of wasted wall-clock time. The circuit logic never changes between iterations; only the parameter values do, so the interpreter repeats the same work on every call.

PennyLane Catalyst eliminates this overhead by JIT-compiling your entire hybrid quantum-classical program into a native binary. Catalyst traces the program once, lowers it through MLIR (a compiler infrastructure that is part of the LLVM project), and produces machine code that later calls run without re-executing the Python function body or dispatching gates through PennyLane. Speedups on repeated circuit evaluations are often in the 5x to 50x range, with the largest gains on small-to-medium circuits where Python overhead dominates execution time. On a 4-qubit VQE loop with 200 optimization steps, a 10x to 30x reduction in total runtime compared to standard PennyLane dispatch is a reasonable expectation, though the exact figure depends on your hardware and library versions.

This tutorial walks through installation, basic compilation, control flow inside JIT regions, a full VQE training loop, MLIR inspection, performance measurement, and the most common mistakes that trip up new Catalyst users.

Installation

Catalyst ships as a separate package that pairs with PennyLane and JAX. You need all four packages.

pip install pennylane pennylane-catalyst jax jaxlib

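Catalyst requires a compatible version of JAX, and the required JAX version changes between Catalyst releases. If you hit version conflicts, pin the versions explicitly. The set below is one example; confirm the matching versions in the Catalyst installation documentation for the release you use: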

pip install pennylane==0.39.0 pennylane-catalyst==0.9.0 jax==0.4.35 jaxlib==0.4.35

Verify the installation:

import pennylane as qml
import catalyst
print(f"PennyLane version: {qml.__version__}")
print(f"Catalyst version:  {catalyst.__version__}")

If the import succeeds without errors, your environment is ready.
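If the versions still conflict, it can help to compare what is installed against your pins before importing anything. The sketch below uses only the standard library's importlib.metadata; the helper name check_pins is hypothetical, and the pin values are simply the example pins from the install command above.

```python
from importlib.metadata import PackageNotFoundError, version

# Example pins copied from the install command above; adjust them for your release
PINS = {
    "pennylane": "0.39.0",
    "pennylane-catalyst": "0.9.0",
    "jax": "0.4.35",
    "jaxlib": "0.4.35",
}


def check_pins(pins):
    """Return {package: (expected, installed)} for every package that differs from its pin.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)  # reads package metadata; does not import the package
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


for name, (expected, installed) in check_pins(PINS).items():
    print(f"{name}: expected {expected}, found {installed}")
```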

Your First Compiled Circuit

The @qml.qjit decorator is the entry point. Place it on any QNode or hybrid function to compile it.

import pennylane as qml
import jax.numpy as jnp

dev = qml.device("lightning.qubit", wires=3)

@qml.qjit
@qml.qnode(dev)
def compiled_circuit(x):
    qml.RX(x, wires=0)          # Rotate qubit 0 around X
    qml.CNOT(wires=[0, 1])      # Entangle qubits 0 and 1
    qml.RY(x * 0.5, wires=1)   # Parameterized Y rotation on qubit 1
    return qml.expval(qml.PauliZ(0))

# First call: triggers compilation (expect a brief pause)
result = compiled_circuit(0.7)
print(f"Result: {result:.6f}")

# Second call: runs the compiled binary directly (fast)
result2 = compiled_circuit(1.4)
print(f"Result: {result2:.6f}")

The first invocation triggers the full compilation pipeline: tracing, MLIR lowering, LLVM code generation, and linking. This typically takes somewhere between 0.5 and 5 seconds, depending on circuit complexity. Every subsequent call with new parameter values of the same shape and dtype reuses the compiled binary and executes in microseconds to milliseconds, so the compilation cost is paid once. (If every argument carries a type hint, such as x: float, Catalyst instead compiles ahead of time, when the decorator is applied.)

Note the decorator order: @qml.qjit must be the outermost decorator, wrapping @qml.qnode. Reversing the order causes an error because qjit needs to trace the QNode, not the other way around.

Hybrid Classical Logic Inside JIT

Catalyst supports classical control flow inside the compiled region using three primitives: qml.for_loop, qml.while_loop, and qml.cond. These compile down to real loop and branch constructs in MLIR, rather than being unrolled or frozen at trace time.

@qml.qjit
@qml.qnode(dev)
def dynamic_circuit(params, n_layers: int):
    @qml.for_loop(0, n_layers, 1)
    def layer(i):
        qml.RY(params[i], wires=0)
        qml.RZ(params[i] * 0.5, wires=1)
        qml.CNOT(wires=[0, 1])

    layer()
    return qml.expval(qml.PauliZ(0))

import jax.numpy as jnp
params = jnp.array([0.3, 0.6, 0.9])
print(dynamic_circuit(params, 3))

qml.for_loop(start, stop, step) takes the loop bounds and is applied as a decorator to the loop body. Inside the body, the loop variable i is a traced integer that Catalyst can reason about at the IR level. qml.while_loop follows the same pattern: it takes a condition function and decorates a body function, which receives the loop-carried value and returns its updated value:

@qml.qjit
@qml.qnode(dev)
def adaptive_circuit(threshold: float):
    qml.Hadamard(wires=0)

    @qml.while_loop(lambda v: v < threshold)
    def repeat_until(v):
        qml.RY(v, wires=0)
        return v + 0.1  # increment the tracked value

    repeat_until(0.0)
    return qml.expval(qml.PauliZ(0))

print(adaptive_circuit(0.5))

For conditional logic, use qml.cond instead of Python if:

@qml.qjit
@qml.qnode(dev)
def conditional_circuit(x: float):
    qml.Hadamard(wires=0)

    # Apply RX only when x > 0.5; the comparison is evaluated at runtime
    def apply_rx():
        qml.RX(x, wires=0)

    qml.cond(x > 0.5, apply_rx)()

    return qml.expval(qml.PauliZ(0))

Why Python Control Flow Breaks Traces

This section explains a subtlety that causes real bugs. Understanding it saves hours of debugging.

When Catalyst compiles a function, it does not execute the Python code in the normal sense. It traces the function: it runs the Python code once with abstract placeholder values, records what operations occur, and compiles that recorded trace into MLIR. This is the same approach JAX uses for jax.jit.

The consequence is that Python control flow is evaluated at trace time, not at runtime. Consider this example:

# WRONG: Python for-loop with a runtime variable
@qml.qjit
@qml.qnode(dev)
def broken_layers(params, n_layers: int):
    for i in range(n_layers):       # n_layers is abstract at trace time
        qml.RY(params[i], wires=0)
        qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

This fails because range(n_layers) requires n_layers to be a concrete Python integer. At trace time, n_layers is an abstract tracer object; Python cannot iterate over it. You get a tracing error.

If n_layers were a compile-time constant (say, you hardcoded range(3)), the loop would be unrolled at trace time into three sequential blocks of gates. That works, but it produces a different compiled binary for each value of the constant, and it cannot handle dynamic values.
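The difference between trace time and runtime can be reproduced without any quantum library. The toy below is a deliberate simplification, not how Catalyst or JAX are implemented: AbstractValue stands in for a tracer, and trace runs a function once while recording the gates it requests. A concrete loop bound unrolls into a fixed gate list, while an abstract value can drive neither range() nor an if:

```python
class AbstractValue:
    """Toy stand-in for a tracer: a named placeholder with no concrete number."""

    def __init__(self, name):
        self.name = name

    def __index__(self):
        # range() asks for a concrete integer; a placeholder cannot supply one
        raise TypeError(f"{self.name} is abstract at trace time")

    def __gt__(self, other):
        # Comparisons only produce another placeholder
        return AbstractValue(f"({self.name} > {other})")

    def __bool__(self):
        # A Python `if` needs a concrete True/False
        raise TypeError(f"{self.name} is abstract at trace time")


def trace(fn, *args):
    """Run fn once and return the list of gates it recorded."""
    tape = []
    fn(tape, *args)
    return tape


def layers(tape, n_layers):
    for i in range(n_layers):  # evaluated by Python while tracing
        tape.append(("RY", i))
        tape.append(("CNOT", (0, 1)))


def branch(tape, x):
    if x > 0.5:  # also evaluated by Python while tracing
        tape.append(("RX", x))
    else:
        tape.append(("RY", x))


print(trace(layers, 3))    # unrolled: three RY/CNOT pairs
print(trace(branch, 0.7))  # only the RX branch was recorded

for fn, name in [(layers, "n_layers"), (branch, "x")]:
    try:
        trace(fn, AbstractValue(name))
    except TypeError as err:
        print("tracing failed:", err)
```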

The fix is qml.for_loop, which compiles to a proper loop in MLIR:

# CORRECT: Catalyst for_loop handles runtime bounds
@qml.qjit
@qml.qnode(dev)
def working_layers(params, n_layers: int):
    @qml.for_loop(0, n_layers, 1)
    def apply_layer(i):
        qml.RY(params[i], wires=0)
        qml.CNOT(wires=[0, 1])

    apply_layer()
    return qml.expval(qml.PauliZ(0))

The same issue applies to if statements. A Python if is evaluated once, during tracing, so it cannot depend on a value that is only known at runtime:

# WRONG: Python if freezes the branch at trace time
@qml.qjit
@qml.qnode(dev)
def broken_conditional(x: float):
    if x > 0.5:          # x is abstract at trace time; Python cannot decide this
        qml.RX(x, wires=0)
    else:
        qml.RY(x, wires=0)
    return qml.expval(qml.PauliZ(0))

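Because x is an argument of the compiled function, it is abstract during tracing, and this code raises a tracing error (JAX cannot convert a tracer to a Python bool). If the condition depended on a plain Python value instead, such as a global variable, tracing would succeed silently and the branch taken at trace time would be baked into the compiled binary, ignoring later changes to that value. Neither behavior is what you want.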

The fix is qml.cond, which creates a proper conditional branch in the compiled IR:

# CORRECT: qml.cond creates a runtime branch, including the else case
@qml.qjit
@qml.qnode(dev)
def working_conditional(x: float):
    def apply_rx():
        qml.RX(x, wires=0)

    def apply_ry():
        qml.RY(x, wires=0)

    qml.cond(x > 0.5, apply_rx, apply_ry)()
    return qml.expval(qml.PauliZ(0))

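The rule of thumb: if the value controlling the flow could change between calls, use Catalyst primitives (qml.for_loop, qml.while_loop, qml.cond). If it is truly a compile-time constant that never changes, Python control flow is acceptable but produces unrolled or frozen code. Catalyst also offers an optional autograph=True mode for qjit that converts many Python if, for, and while statements into these primitives automatically; check the Catalyst documentation for its limitations before relying on it.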

Compiling a VQE Training Loop

The most impactful use case for Catalyst is eliminating Python overhead from optimization loops. A VQE workflow evaluates the circuit and its gradient thousands of times with different parameters. Compiling both the forward pass and the gradient computation yields the largest speedups.

import jax.numpy as jnp
import pennylane as qml

dev = qml.device("lightning.qubit", wires=4)

@qml.qnode(dev)
def vqe_circuit(params):
    # Single-qubit rotation layer
    for i in range(4):
        qml.RY(params[i], wires=i)
    # Entangling layer
    for i in range(3):
        qml.CNOT(wires=[i, i + 1])
    # Second rotation layer
    for i in range(4):
        qml.RZ(params[4 + i], wires=i)
    # Measure two-body correlation
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# Compile the forward pass and the gradient as two separate programs.
# qml.grad is created inside the qjit function, so Catalyst compiles the differentiation.
energy_fn = qml.qjit(vqe_circuit)

@qml.qjit
def grad_fn(params):
    return qml.grad(vqe_circuit)(params)

# Start away from zero: at params = 0 this cost is stationary (zero gradient)
params = jnp.full(8, 0.1)
lr = 0.05

for step in range(100):
    g = grad_fn(params)
    params = params - lr * g
    if step % 20 == 0:
        energy = energy_fn(params)
        print(f"Step {step:3d} | energy = {energy:.6f}")

Note that the for i in range(4) loops inside the circuit are fine here because 4 is a compile-time constant. Catalyst unrolls them at trace time into four explicit gate operations. The outer optimization loop (for step in range(100)) runs in Python, which is correct: it is the driver that calls the compiled functions, not part of the compiled region.

The forward pass and the gradient function are compiled as two separate programs. Each call from the Python loop dispatches directly to a compiled binary, bypassing PennyLane's Python-level gate dispatch entirely. Catalyst differentiates the QNode with the method selected by its diff_method; with the default ("best") on lightning.qubit this is typically adjoint differentiation, and parameter-shift is also supported. The entire differentiation pipeline is compiled.

You can also fuse the forward pass, gradient, and parameter update into one qjit function. Inside a qjit function, qml.grad dispatches to Catalyst's own differentiation (catalyst.grad):

@qml.qjit
def compiled_vqe_step(params):
    energy = vqe_circuit(params)
    g = qml.grad(vqe_circuit)(params)
    new_params = params - 0.05 * g
    return new_params, energy

This compiles the forward pass, gradient, and parameter update into a single binary, eliminating even the Python overhead of calling two separate compiled functions.

Performance Comparison

To see the speedup concretely, time the same gradient loop with and without compilation. The following benchmark compares standard PennyLane execution against Catalyst-compiled execution on a 4-qubit VQE circuit.

import time
import pennylane as qml
from pennylane import numpy as pnp   # autograd-backed NumPy for the standard path
import jax.numpy as jnp

dev = qml.device("lightning.qubit", wires=4)
n_steps = 200

# --- Standard PennyLane (no JIT) ---

@qml.qnode(dev)
def standard_circuit(params):
    for i in range(4):
        qml.RY(params[i], wires=i)
    for i in range(3):
        qml.CNOT(wires=[i, i + 1])
    for i in range(4):
        qml.RZ(params[4 + i], wires=i)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

standard_grad = qml.grad(standard_circuit)

# Outside qjit, qml.grad uses autograd, so the parameters must be trainable
# PennyLane NumPy arrays (a plain jnp array would be treated as non-trainable)
params_std = pnp.zeros(8, requires_grad=True)
start = time.perf_counter()
for step in range(n_steps):
    g = standard_grad(params_std)
    params_std = params_std - 0.05 * g
standard_time = time.perf_counter() - start
print(f"Standard PennyLane: {standard_time:.2f}s for {n_steps} steps")

# --- Catalyst-compiled ---

# Same circuit; qml.grad is created inside the qjit function,
# so Catalyst compiles the differentiation as well
@qml.qjit
def compiled_grad(params):
    return qml.grad(standard_circuit)(params)

# Warm up: trigger compilation before timing
params_jit = jnp.zeros(8)
_ = compiled_grad(params_jit)

start = time.perf_counter()
for step in range(n_steps):
    g = compiled_grad(params_jit)
    params_jit = params_jit - 0.05 * g
compiled_time = time.perf_counter() - start
print(f"Catalyst compiled:  {compiled_time:.2f}s for {n_steps} steps")
print(f"Speedup: {standard_time / compiled_time:.1f}x")

Illustrative results on a modern CPU with lightning.qubit (your numbers will vary with hardware and library versions):

Configuration         200 steps   Per step
Standard PennyLane    ~12 s       ~60 ms
Catalyst compiled     ~0.8 s      ~4 ms
Speedup               ~15x        ~15x

The end-to-end speedup depends on the number of iterations, because the benchmark above deliberately excludes the one-time compilation cost through its warm-up call. For a 10-step run, compilation dominates and you may see no benefit or even a slowdown. For 1,000+ steps, the amortized compilation cost becomes negligible and the end-to-end speedup approaches the steady-state per-step speedup. The break-even point, where Catalyst pays for itself, is typically somewhere around 20 to 50 iterations.
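The amortization argument is simple arithmetic. Assuming, for illustration, the per-step times from the table above and a 2 s compile time (inside the 0.5 to 5 s range quoted earlier; measure your own), a small helper shows where compilation pays off and how the end-to-end speedup climbs toward the steady-state ~15x:

```python
import math


def break_even_steps(compile_s, step_uncompiled_s, step_compiled_s):
    """Smallest step count at which compile time + compiled steps <= uncompiled steps."""
    saving_per_step = step_uncompiled_s - step_compiled_s
    if saving_per_step <= 0:
        return math.inf  # compiled steps are not faster, so compilation never pays off
    return math.ceil(compile_s / saving_per_step)


def end_to_end_speedup(n_steps, compile_s, step_uncompiled_s, step_compiled_s):
    """Speedup over n_steps when the one-time compilation cost is included."""
    return (n_steps * step_uncompiled_s) / (compile_s + n_steps * step_compiled_s)


# Illustrative inputs: 2 s compile (assumed), 60 ms vs 4 ms per step (from the table)
print(break_even_steps(2.0, 0.060, 0.004))
for n in (10, 36, 100, 1000):
    print(n, round(end_to_end_speedup(n, 2.0, 0.060, 0.004), 2))
```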

Larger circuits (more qubits, deeper ansätze) shift the bottleneck from Python overhead to actual simulation time. On a 20-qubit circuit, the simulation itself dominates, and the relative speedup from eliminating Python dispatch may shrink to something like 2x to 5x. Catalyst still helps, but the gains are proportionally smaller because less time was being spent in Python to begin with.

Inspecting the MLIR Output

Catalyst lowers your quantum program to MLIR (Multi-Level Intermediate Representation), a compiler infrastructure that is now part of the LLVM project. MLIR provides a framework for defining and transforming intermediate representations at multiple levels of abstraction. In the Catalyst pipeline, your program passes through several stages:

  1. Python tracing captures the program structure as a JAX program representation (a jaxpr)
  2. MLIR lowering converts the trace into MLIR, with quantum operations expressed in Catalyst's quantum dialect
  3. Optimization passes apply compiler transformations: classical optimizations such as constant folding and dead-code elimination, plus optional quantum passes such as cancelling adjacent inverse gates and merging rotations
  4. LLVM code generation produces native machine code

The compiled output is a native binary rather than Python bytecode. After compilation, a call dispatches straight into that binary, so the per-gate Python dispatch described in the overview disappears; that is where the speedup comes from.

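You can print the MLIR that Catalyst generates for a compiled function. (To keep files from later pipeline stages as well, qjit accepts a keep_intermediate=True option.)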

@qml.qjit
@qml.qnode(dev)
def inspectable_circuit(x: float):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

# Print the MLIR intermediate representation
print(inspectable_circuit.mlir)

The output shows the quantum operations represented as MLIR operations in Catalyst’s quantum dialect. You will see constructs like quantum.custom "RX" and quantum.custom "CNOT" alongside classical operations represented in standard MLIR dialects. This output is useful for:

  • Debugging compilation failures: if the MLIR looks wrong, you can trace back to the Python source
  • Verifying optimization: checking that the compiler fused gates or eliminated redundant operations
  • Understanding lowering: seeing exactly how your Python-level constructs translate to the IR

Interaction with jax.jit

Catalyst’s @qml.qjit and JAX’s @jax.jit are separate compilation systems with different backends. @qml.qjit compiles through Catalyst’s MLIR pipeline, while @jax.jit compiles through JAX’s XLA backend. They serve different purposes:

  • Use @qml.qjit for functions that contain quantum operations (QNodes)
  • Use @jax.jit for purely classical JAX computations that do not involve PennyLane circuits

Catalyst's documentation does describe calling @qml.qjit functions from inside larger JAX workflows, but mixing the two tracing systems adds complexity and restrictions. For a simple pipeline with a classical preprocessing step followed by a quantum circuit, it is easiest to compile the stages separately and call them in sequence:

import jax

@jax.jit
def classical_preprocess(raw_params):
    """Purely classical computation, compiled via JAX/XLA."""
    return jnp.tanh(raw_params) * jnp.pi

@qml.qjit
@qml.qnode(dev)
def quantum_step(params):
    """Quantum circuit, compiled via Catalyst/MLIR."""
    for i in range(3):
        qml.RY(params[i], wires=i)
    qml.CNOT(wires=[0, 1])
    qml.CNOT(wires=[1, 2])
    return qml.expval(qml.PauliZ(0))

# Call them sequentially from Python
raw = jnp.array([0.5, 1.2, -0.3])
processed = classical_preprocess(raw)
result = quantum_step(processed)

Both functions are compiled to native code, but through different compiler stacks. The Python driver orchestrates the calls, which is fine because the driver itself is trivially cheap.

Device Compatibility

Catalyst compiles circuits against a specific device backend. The device determines which gates are natively supported, how measurements work, and what simulation or hardware execution engine runs the compiled binary.

Supported devices include the following (the list changes between releases, so check the Catalyst documentation for the current set):

  • lightning.qubit: the recommended default for CPU-based simulation. High-performance C++ state-vector simulator. Best choice for development, testing, and benchmarking.
  • lightning.kokkos: performance-portable simulator that can target CPUs or GPUs via the Kokkos framework. Use this when you need GPU acceleration for larger circuits.
  • lightning.gpu: GPU state-vector simulator built on NVIDIA cuQuantum.
  • Amazon Braket devices: select Braket backends, accessed through the Amazon Braket PennyLane plugin, support Catalyst compilation for hybrid workflows that target real quantum hardware.

import pennylane as qml

# Recommended: lightning.qubit for local development
dev_cpu = qml.device("lightning.qubit", wires=4)

# GPU acceleration for larger circuits
# dev_gpu = qml.device("lightning.kokkos", wires=20)

@qml.qjit
@qml.qnode(dev_cpu)
def my_circuit(x: float):
    qml.Hadamard(wires=0)
    qml.RX(x, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

print(f"Device: {dev_cpu.name}")
print(f"Result: {my_circuit(0.5):.6f}")

Not all PennyLane devices support Catalyst. The standard default.qubit device, for instance, cannot be used as a @qml.qjit backend in the versions covered here. If you try to compile a circuit targeting an unsupported device, Catalyst raises an error at compilation time. Always check the Catalyst documentation for the current list of supported backends when targeting hardware.

Common Mistakes

1. Using Python conditionals with dynamic values inside @qml.qjit

# WRONG: Python if with a traced value
@qml.qjit
def bad_branch(x):
    if x > 0:          # x is a tracer, not a concrete number
        return x * 2
    return x * 3

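Because x is a traced value, this raises an error at trace time. (If the condition depended on a plain Python value instead, the branch would be frozen silently into the binary.) Use qml.cond instead; inside qjit it handles purely classical branches as well as quantum operations.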

2. Forgetting that the first call compiles

The first invocation of a @qml.qjit function triggers compilation. If you are benchmarking, always run a warm-up call before timing:

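import time
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def my_circuit(x):   # a plain QNode; qml.qjit is applied below
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

compiled_fn = qml.qjit(my_circuit)

# Warm-up (compilation happens here)
_ = compiled_fn(0.0)

# Now time the actual execution
start = time.perf_counter()
for _ in range(1000):
    compiled_fn(0.5)
elapsed = time.perf_counter() - start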

Without the warm-up call, your benchmark includes compilation time and wildly overstates the per-call cost.

3. Using NumPy instead of JAX NumPy

Inside a @qml.qjit function, all array operations must use jax.numpy, not standard numpy. Catalyst traces through JAX, and NumPy functions cannot operate on the abstract tracer values, so calling them on a traced argument raises an error during tracing.

import numpy as np
import jax.numpy as jnp

# WRONG: NumPy functions cannot consume traced values
@qml.qjit
def bad_preprocess(x):
    return np.sin(x) + np.cos(x)    # raises an error at trace time

# CORRECT: jax.numpy operations are traced
@qml.qjit
def good_preprocess(x):
    return jnp.sin(x) + jnp.cos(x)  # jnp operations are captured in MLIR

This is easy to miss if you have import numpy as np at the top of your file and are accustomed to using it everywhere.
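The failure is easy to reproduce with plain jax.jit, which uses the same style of tracing (assuming qjit behaves the same way here, since it also traces with JAX): the NumPy version fails with JAX's TracerArrayConversionError, while the jax.numpy version compiles.

```python
import jax
import jax.numpy as jnp
import numpy as np


def with_numpy(x):
    return np.sin(x) + np.cos(x)  # NumPy tries to turn the tracer into an ndarray


def with_jax_numpy(x):
    return jnp.sin(x) + jnp.cos(x)  # traced and compiled


def numpy_fails_under_jit():
    """Return True if tracing the NumPy version raises the tracer conversion error."""
    try:
        jax.jit(with_numpy)(0.3)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(jax.jit(with_jax_numpy)(0.3))  # close to np.sin(0.3) + np.cos(0.3)
print(numpy_fails_under_jit())
```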

4. Passing Python lists instead of JAX arrays

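Function arguments to @qml.qjit functions should be JAX-compatible types. Plain Python floats are fine for scalar arguments, but a Python list is treated as a container of separate scalars rather than as one array. Dynamic indexing such as params[i] inside qml.for_loop then fails, and a list with a different length counts as a new signature.

# WRONG: the list arrives as separate scalars, so params[i] with a traced i fails
result = dynamic_circuit([0.3, 0.6, 0.9], 3)

# SAFE: a JAX array supports indexing with the traced loop variable
result = dynamic_circuit(jnp.array([0.3, 0.6, 0.9]), 3)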

Always convert inputs to jnp.array before passing them to compiled functions.
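The distinction comes from how JAX-style tracing flattens arguments into pytrees. Assuming qjit flattens its arguments the same way JAX does, this small check shows the difference:

```python
import jax
import jax.numpy as jnp

as_list = [0.1, 0.2, 0.3, 0.4]
as_array = jnp.array(as_list)

# A list is a pytree container: each element becomes a separate leaf (a separate scalar)
list_leaves = jax.tree_util.tree_leaves(as_list)
# An array is one leaf with shape (4,), which can be indexed by a traced integer
array_leaves = jax.tree_util.tree_leaves(as_array)

print(len(list_leaves), len(array_leaves))  # 4 1
```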

5. Recompilation from changing argument shapes

Catalyst caches the compiled binary based on the shapes and dtypes of the input arguments. If you change the shape of an input between calls, Catalyst recompiles:

# compiled_fn: any @qml.qjit function that takes a 1-D array argument
compiled_fn(jnp.zeros(8))   # Compiles for shape (8,)
compiled_fn(jnp.zeros(10))  # Triggers a NEW compilation for shape (10,)

If your workflow requires variable-length inputs, pad them to a fixed size or restructure to avoid repeated recompilation.
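A minimal padding helper, as a sketch (pad_to_fixed and MAX_LEN are hypothetical names): it zero-pads to a fixed length so every call has the same shape, and returns a mask marking the real entries.

```python
import jax.numpy as jnp

MAX_LEN = 8  # hypothetical fixed length chosen for this workflow


def pad_to_fixed(values, size=MAX_LEN):
    """Zero-pad a 1-D array to `size`; return (padded, mask) where mask marks real entries."""
    values = jnp.asarray(values)
    n = values.shape[0]
    if n > size:
        raise ValueError(f"input length {n} exceeds the fixed size {size}")
    padded = jnp.pad(values, (0, size - n))  # zeros appended at the end
    mask = jnp.arange(size) < n              # True for the original entries
    return padded, mask


padded, mask = pad_to_fixed(jnp.array([0.3, 0.6, 0.9]))
print(padded.shape, int(mask.sum()))  # (8,) 3
```

Zero is a convenient pad for rotation angles because RY(0) is the identity, but entangling gates in padded layers would still act. Pass the true length separately (for example as the n_layers bound of qml.for_loop, a scalar whose value can change without recompiling) rather than looping over the padded size.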

Next Steps

With Catalyst compilation in your toolkit, several directions open up:

  • Explore catalyst.grad and catalyst.jacobian for compiling differentiation directly into the MLIR pipeline. Their method option chooses between automatic differentiation (method="auto", which uses each QNode's diff_method, such as adjoint or parameter-shift) and finite differences (method="fd").
  • Try lightning.kokkos for GPU-accelerated simulation. Once your VQE loop is compiled and correct on lightning.qubit, switching to the Kokkos backend is a one-line device change. GPU simulation tends to pay off for larger circuits (roughly beyond 15 qubits), where simulation time rather than Python overhead dominates.
  • Profile your compiled programs using Catalyst's built-in instrumentation. The catalyst.debug.instrumentation context manager records timing breakdowns for the compilation and execution phases of the code it wraps.
  • Explore combining Catalyst with PennyLane's circuit-cutting tools to decompose large circuits into smaller compiled fragments that can execute on limited-qubit hardware; check the Catalyst documentation for which PennyLane transforms are supported inside qjit.
  • Read the Catalyst source and MLIR dialects. If you want to understand exactly how quantum operations are represented at the compiler level, the Catalyst GitHub repository contains the dialect definitions. Understanding the IR helps you write circuits that compile more efficiently.

Summary

Catalyst is the path from prototype to production-grade performance for hybrid quantum-classical programs. The core workflow is straightforward: annotate with @qml.qjit, use Catalyst control-flow primitives (qml.for_loop, qml.while_loop, qml.cond) for dynamic logic, and let the MLIR pipeline compile everything to native machine code. The payoff, often 5x to 50x faster execution on repeated evaluations, can turn VQE and QAOA training loops from painfully slow into practical. Inspect the MLIR when things go wrong, warm up before benchmarking, keep your arrays in JAX, and choose a supported device backend. That covers most of what you need to use Catalyst effectively.
