qml.labs.phox

Phase optimization with JAX (PHOX)

CircuitConfig(gates, observables, n_samples, ...)

Configuration data for an IQP circuit simulation.

build_expval_func(config)

Factory that returns a function for computing expectation values.

bitflip_expval(generators, params, ops)

Compute the expectation value for the Bitflip noise model.
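The following self-contained NumPy sketch illustrates what a bit-flip noise model can compute. It does not come from the phox source. It assumes a classical model in which each gate generator g_j, given as a set of qubits, independently flips all of its bits with probability sin²(θ_j), starting from the all-zeros string. Under that assumption, the expectation of a Z-string observable Z_a is the product of cos(2θ_j) over the generators that overlap a on an odd number of qubits. Each such flip contributes a factor of 1 − 2 sin²(θ_j) = cos(2θ_j). The function names and the 0/1 array layout are hypothetical.

```python
import itertools

import numpy as np


def bitflip_expval_sketch(generators, params, op):
    """Exact <Z_op> for a hypothetical classical bit-flip model.

    generators: (n_gates, n_qubits) 0/1 array; row j marks the qubits flipped by gate j.
    params:     (n_gates,) angles; gate j fires with probability sin(theta_j)**2.
    op:         (n_qubits,) 0/1 array marking the support of the Z-string observable.
    """
    # A firing gate changes the sign of Z_op only if it overlaps op on an odd number of qubits.
    odd_overlap = (generators @ op) % 2 == 1
    return np.prod(np.cos(2 * params[odd_overlap]))


def brute_force_expval(generators, params, op):
    """Reference value: explicit sum over every subset of gates that fire."""
    probs = np.sin(params) ** 2
    total = 0.0
    for fired in itertools.product([0, 1], repeat=len(params)):
        fired = np.array(fired)
        weight = np.prod(np.where(fired == 1, probs, 1 - probs))
        bits = (fired @ generators) % 2        # final bitstring
        total += weight * (-1) ** (bits @ op)  # eigenvalue of Z_op on that bitstring
    return total


generators = np.array([[1, 1, 0], [0, 1, 1], [1, 0, 0]])
params = np.array([0.3, 0.7, 1.1])
op = np.array([1, 1, 0])  # Z_0 Z_1: overlaps gate 0 evenly and gates 1 and 2 oddly
print(bitflip_expval_sketch(generators, params, op))
print(brute_force_expval(generators, params, op))
```

The two printed values agree. The closed-form product avoids the exponential enumeration, which is why a formula of this kind scales to many gates.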

train(optimizer, loss, stepsize, n_iters, ...)

Main training function: runs the optimizer on the loss for n_iters steps and returns a TrainingResult.

training_iterator(optimizer, loss, stepsize, ...)

Generator that yields training results in batches of size 'unroll_steps'.

TrainingOptions([unroll_steps, val_kwargs, ...])

Configuration options for training.

TrainingResult(final_params, losses, ...)

Container for final training results.

BatchResult(params, state, key, key_val, ...)

Result from a single batch (unrolled chunk) of training steps.
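A framework-free sketch can show how training_iterator, unroll_steps, BatchResult and TrainingResult relate to one another. Everything in it is hypothetical. Plain gradient descent on a toy quadratic loss stands in for the real optimizers, and dictionaries stand in for the actual result classes. The generator advances the parameters unroll_steps iterations at a time and yields once per chunk. A driver then collects the chunks into one final result.

```python
import numpy as np


def loss_and_grad(params):
    """Toy quadratic loss sum((params - 1)^2) and its gradient."""
    diff = params - 1.0
    return float(np.sum(diff**2)), 2.0 * diff


def chunked_training(params, stepsize, n_iters, unroll_steps):
    """Hypothetical generator: yields one result per chunk of unroll_steps iterations."""
    done = 0
    while done < n_iters:
        n_steps = min(unroll_steps, n_iters - done)  # the last chunk may be shorter
        losses = []
        for _ in range(n_steps):
            loss, grad = loss_and_grad(params)
            losses.append(loss)
            params = params - stepsize * grad
        done += n_steps
        yield {"params": params, "losses": losses}


def run(params, stepsize, n_iters, unroll_steps):
    """Hypothetical driver: concatenates per-chunk losses into one final result."""
    all_losses = []
    for batch in chunked_training(params, stepsize, n_iters, unroll_steps):
        all_losses.extend(batch["losses"])
        params = batch["params"]
    return {"final_params": params, "losses": np.array(all_losses)}


result = run(np.zeros(3), stepsize=0.1, n_iters=25, unroll_steps=10)
print(len(result["losses"]), result["losses"][-1])
```

Chunking matters most in a JAX setting. There, a whole chunk of steps can be compiled into a single call, and the Python loop then runs only once per chunk.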

Circuit construction utilities

create_lattice_gates(rows, cols[, distance, ...])

Generates gates based on nearest-neighbor interactions on a 2D lattice.
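One way to read the distance argument is as a cutoff on lattice separation. The sketch below illustrates that reading and is not the phox implementation. It enumerates pairs of sites on a rows × cols grid whose Manhattan distance is at most distance, using row-major qubit numbering. Both the Manhattan metric and the numbering are assumptions. With distance=1 the result is exactly the set of nearest-neighbor couplings.

```python
from itertools import combinations


def lattice_pairs(rows, cols, distance=1):
    """Hypothetical helper: qubit pairs within Manhattan distance `distance`."""
    # Row-major numbering: site (r, c) is qubit r * cols + c.
    sites = [(r, c) for r in range(rows) for c in range(cols)]
    pairs = []
    for (i, (r1, c1)), (j, (r2, c2)) in combinations(enumerate(sites), 2):
        if abs(r1 - r2) + abs(c1 - c2) <= distance:
            pairs.append((i, j))
    return pairs


pairs = lattice_pairs(3, 3, distance=1)
print(len(pairs))  # 12: 6 horizontal + 6 vertical bonds
print(pairs[:4])
```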

create_local_gates(n_qubits[, max_weight])

Generates a gate dictionary for the Phox simulator containing all gates whose generators have Pauli weight less than or equal to max_weight.
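The set described for create_local_gates can be enumerated directly. Every subset of qubits with between 1 and max_weight elements is one generator support. There are therefore C(n_qubits, 1) + ... + C(n_qubits, max_weight) such gates if the trivial weight-0 generator is excluded. The sketch makes two assumptions for illustration. First, the weight-0 generator is left out. Second, the gate dictionary maps an integer gate index to a list of generators, each given as a list of qubit indices. The actual phox format may differ, and local_gates_sketch is a hypothetical name.

```python
from itertools import combinations
from math import comb


def local_gates_sketch(n_qubits, max_weight=2):
    """Hypothetical: one single-generator gate per qubit subset of size 1..max_weight."""
    gates = {}
    for weight in range(1, max_weight + 1):
        for support in combinations(range(n_qubits), weight):
            gates[len(gates)] = [list(support)]
    return gates


gates = local_gates_sketch(4, max_weight=2)
expected = sum(comb(4, w) for w in range(1, 3))  # C(4, 1) + C(4, 2) = 10
print(len(gates), expected)
print(gates[4])  # [[0, 1]], the first weight-2 generator
```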

create_random_gates(n_qubits, n_gates[, ...])

Generates a dictionary of random gates.

generate_pauli_observables(n_qubits[, ...])

Generates a batch of Pauli observables.
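The workflow example below uses the Z-basis case (orders=[2], bases=["Z"]). For that case, a batch of observables can be represented as a 0/1 matrix with one row per observable, marking the qubits that carry a Z. This encoding is an assumption made for illustration, and z_observables_sketch is a hypothetical name. The sketch mainly shows how the batch size grows as C(n_qubits, order).

```python
from itertools import combinations

import numpy as np


def z_observables_sketch(n_qubits, orders=(2,)):
    """Hypothetical: each row marks the support of one Z-string of the given orders."""
    rows = []
    for order in orders:
        for support in combinations(range(n_qubits), order):
            row = np.zeros(n_qubits, dtype=int)
            row[list(support)] = 1
            rows.append(row)
    return np.stack(rows)


obs = z_observables_sketch(9, orders=[2])
print(obs.shape)  # (36, 9): C(9, 2) two-body Z correlators
```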

Workflow

pennylane.labs.phox provides a compact toolkit for constructing and simulating phase optimization circuits with JAX. The usual workflow is:

  1. Use helpers in pennylane.labs.phox.utils to assemble gates and observables.

  2. Configure the circuit with CircuitConfig.

  3. Build an expectation-value function with build_expval_func() and evaluate it for different parameter sets.

import jax

from pennylane.labs.phox import (
    CircuitConfig,
    build_expval_func,
    create_lattice_gates,
    generate_pauli_observables,
)

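# A 3x3 grid of qubits
n_rows, n_cols = 3, 3
n_qubits = n_rows * n_cols

# Lattice gates (distance=1, max_weight=2) and all two-body Z correlators as observables
gates = create_lattice_gates(n_rows, n_cols, distance=1, max_weight=2)
observables = generate_pauli_observables(n_qubits, orders=[2], bases=["Z"])

# Split the PRNG key so the initial parameters and the sampling use independent randomness
key = jax.random.PRNGKey(0)
param_key, sample_key = jax.random.split(key)
params = jax.random.uniform(param_key, shape=(len(gates),))  # one parameter per gate

config = CircuitConfig(
    gates=gates,
    observables=observables,
    n_samples=4000,  # number of samples for the stochastic estimates
    key=sample_key,
    n_qubits=n_qubits,
)

# JIT-compile the returned function; it maps parameters to (expvals, std_errs)
expval_fn = jax.jit(build_expval_func(config))
expvals, std_errs = expval_fn(params)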
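expval_fn returns a standard error alongside each expectation value, which indicates that the estimates are averages over a finite number of samples set by n_samples. The NumPy sketch below is not the phox estimator. It only shows the generic relationship for a sample mean: the standard error is the sample standard deviation divided by sqrt(n_samples), so quadrupling n_samples roughly halves it. The ±1 samples are synthetic, drawn so that their mean is 0.3.

```python
import numpy as np


def mean_and_stderr(samples):
    """Sample mean and standard error of the mean along axis 0."""
    n = samples.shape[0]
    mean = samples.mean(axis=0)
    stderr = samples.std(axis=0, ddof=1) / np.sqrt(n)
    return mean, stderr


rng = np.random.default_rng(0)
for n_samples in (1000, 4000):
    # Synthetic +/-1 outcomes: P(+1) = (1 + 0.3) / 2 gives an expectation value of 0.3
    samples = rng.choice([1.0, -1.0], size=(n_samples, 1), p=[0.65, 0.35])
    mean, stderr = mean_and_stderr(samples)
    print(n_samples, mean[0], stderr[0])
```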

Training

Below is a small training loop that minimizes the sum of all two-body Z correlators on the same 3x3 lattice. The loss function reuses the compiled expval_fn from above.

import jax.numpy as jnp

from pennylane.labs.phox import TrainingOptions, train

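# train passes loss_kwargs to the loss as keyword arguments, so the argument
# name must match the "params" key below
def loss_fn(params):
    expvals, _ = expval_fn(params)
    return jnp.sum(expvals)

result = train(
    optimizer="Adam",
    loss=loss_fn,
    stepsize=0.05,
    n_iters=200,
    loss_kwargs={"params": params},  # initial parameters
    options=TrainingOptions(unroll_steps=10, random_state=1234),  # steps run in chunks of 10
)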

print("Final loss:", float(result.losses[-1]))
print("Optimized parameters:", result.final_params)