Source code for pennylane.labs.phox.training
# Copyright 2026 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training utilities for Phox.
"""
import time
from dataclasses import dataclass
from functools import partial
from inspect import signature
from typing import Any, Callable, Iterator, NamedTuple
import jax
import jax.numpy as jnp
try:
import jaxopt
import optax
from tqdm import tqdm
except (ModuleNotFoundError, ImportError) as import_error:
pass
[docs]
@dataclass
class TrainingOptions:
"""
Configuration options for training.
Args:
unroll_steps (int): How many optimization steps to run on the GPU before yielding
control back to Python. Higher = Faster. Lower = More interactive/granular logging.
Defaults to 1 (slow, good for debugging).
val_kwargs (dict[str, Any] | None): Arguments for the loss function to be used during validation.
convergence_interval (int): Number of steps over which to check for convergence.
Defaults to 100.
random_state (int): Seed for PRNGKey.
opt_jit (bool): Whether to JIT the optimizer creation (usually False is fine).
"""
unroll_steps: int = 1
val_kwargs: dict[str, Any] | None = None
convergence_interval: int = 100
random_state: int = 666
opt_jit: bool = False
[docs]
class TrainingResult(NamedTuple):
    """Immutable record of what a complete training run produced."""

    # Parameters held when the training loop stopped.
    final_params: jnp.ndarray
    # Training loss recorded at every step, in order.
    losses: jnp.ndarray
    # Validation loss recorded at every step, in order.
    val_losses: jnp.ndarray
    # Wall-clock duration of the run, in seconds.
    run_time: float
[docs]
class BatchResult(NamedTuple):
    """Outcome of one unrolled chunk of consecutive training steps."""

    # Parameters after the last step of the chunk.
    params: jnp.ndarray
    # Optimizer state after the last step of the chunk.
    # NOTE(review): annotated as an array, but this holds whatever the optimizer's
    # ``update`` returns as its state -- confirm the intended type.
    state: jnp.ndarray
    # PRNG key to feed into the next chunk's training steps.
    key: jax.Array
    # PRNG key to feed into the next chunk's validation evaluations.
    key_val: jax.Array
    # Training loss of every step in the chunk.
    losses: jnp.ndarray
    # Validation loss of every step in the chunk.
    val_losses: jnp.ndarray
def _prepare_loss_function(loss: Callable) -> Callable:
"""
Wraps the loss function to ensure it accepts a 'key' argument.
If the original function doesn't accept 'key', we consume and ignore it.
Args:
loss (Callable): The original loss function.
Returns:
Callable: A wrapped loss function that accepts a ``key`` argument, regardless of whether the original did.
"""
if "key" in signature(loss).parameters:
return loss
return lambda params, key, **kwargs: loss(params, **kwargs)
def _create_optimizer(name: str, loss_fn: Callable, stepsize: float, opt_jit: bool):
"""
Create the JAX optimizer instance.
Args:
name (str): The name of the optimizer to create ('GradientDescent', 'Adam', 'BFGS').
loss_fn (Callable): The loss function to minimize.
stepsize (float): The step size (learning rate) for the optimizer.
opt_jit (bool): Whether to JIT compile the optimizer update.
Returns:
jaxopt.Optimizer: An instance of the requested optimizer.
Raises:
ValueError: If the optimizer name is not recognized.
"""
# pylint: disable=import-outside-toplevel
if name == "GradientDescent":
return jaxopt.GradientDescent(loss_fn, stepsize=stepsize, verbose=False, jit=opt_jit)
if name == "Adam":
return jaxopt.OptaxSolver(loss_fn, optax.adam(stepsize), verbose=False, jit=opt_jit)
if name == "BFGS":
return jaxopt.BFGS(loss_fn, verbose=False, jit=opt_jit)
raise ValueError(
f"Optimizer {name} not recognized. Choose from 'Adam', 'BFGS', 'GradientDescent'."
)
def _check_convergence(losses: jnp.ndarray, convergence_interval: int) -> bool:
"""
Check for convergence based on loss history.
Args:
losses (jnp.ndarray): Array of recorded loss values.
convergence_interval (int): number of steps to look back when comparing means.
Returns:
bool: True if converged, False otherwise.
"""
recent = losses[-convergence_interval:]
previous = losses[-2 * convergence_interval : -convergence_interval]
avg1 = jnp.mean(recent)
avg2 = jnp.mean(previous)
std1 = jnp.std(recent)
# Stop if improvement is statistically insignificant or loss increases
cond1 = jnp.abs(avg2 - avg1) <= std1 / jnp.sqrt(convergence_interval) / 2
cond2 = avg1 > avg2
return cond1 or cond2
def _update_step_scan(carry, _, opt, loss_fn, loss_kwargs, val_kwargs, validation, optimizer_name):
# pylint: disable=too-many-arguments
"""
Single step update logic to be scanned.
Args:
carry (list): List of carried state [params, state, key, key_val].
_ (Any): Unused variable to accommodate `jax.lax.scan`
opt (jaxopt.Optimizer): The optimizer instance.
loss_fn (Callable): The loss function.
loss_kwargs (dict[str, Any]): Arguments for the loss function.
val_kwargs (dict[str, Any]): Arguments for the validation function.
validation (bool): Whether validation is enabled.
optimizer_name (str): Name of the optimizer.
Returns:
tuple[list, list]: Tuple containing:
- The new carry state [params, state, key2, key2_val].
- The stacked list [training_loss, validation_loss].
"""
params, state, key, key_val = carry
key1, key2 = jax.random.split(key, 2)
key1_val, key2_val = jax.random.split(key_val, 2)
params, state = opt.update(params, state, **loss_kwargs, key=key1)
v_loss = loss_fn(params, **val_kwargs, key=key1_val) if validation else 0.0
if optimizer_name == "GradientDescent":
t_loss = loss_fn(params, **loss_kwargs, key=key1)
else:
t_loss = state.value
return [params, state, key2, key2_val], [t_loss, v_loss]
[docs]
def training_iterator(
    optimizer: str,
    loss: Callable,
    stepsize: float,
    loss_kwargs: dict[str, Any],
    options: TrainingOptions | None = None,
) -> Iterator[BatchResult]:
    """
    Endless generator producing one :class:`BatchResult` per chunk of ``unroll_steps`` steps.

    The generator never terminates on its own; the consumer decides when to stop pulling
    chunks. Each chunk is executed by a single jitted ``jax.lax.scan`` call.

    Args:
        optimizer (str): Name of the optimizer to use. Options are "GradientDescent", "Adam", or "BFGS".
        loss (Callable): The loss function. It may or may not accept a ``key`` argument.
        stepsize (float): The learning rate.
        loss_kwargs (dict[str, Any]): Arguments to pass to the loss function. Must contain the
            initial parameters under ``"params"``; may contain a PRNG key under ``"key"``.
            The dictionary itself is not modified.
        options (TrainingOptions | None): Configuration options for training.
            See :class:`TrainingOptions` for further details.

    Yields:
        Iterator[BatchResult]: An iterator over batch results. See :class:`BatchResult` for further details.
    """
    settings = options if options else TrainingOptions()
    chunk_len = max(1, settings.unroll_steps)

    keyed_loss = _prepare_loss_function(loss)
    solver = _create_optimizer(optimizer, keyed_loss, stepsize, settings.opt_jit)

    # Work on copies so the caller's dictionaries keep their "params"/"key" entries.
    train_kwargs = loss_kwargs.copy()
    params = train_kwargs.pop("params")

    validation = settings.val_kwargs is not None
    held_out_kwargs = settings.val_kwargs.copy() if validation else {}

    # Keys supplied through the kwargs take precedence over the seeded defaults.
    seed_key = jax.random.PRNGKey(settings.random_state)
    default_key, default_key_val = jax.random.split(seed_key, 2)
    key = train_kwargs.pop("key", default_key)
    key_val = held_out_kwargs.pop("key", default_key_val)

    state = solver.init_state(params, **train_kwargs, key=key)

    def scan_body(carry, step_input):
        # Bind everything that stays constant across steps.
        return _update_step_scan(
            carry,
            step_input,
            opt=solver,
            loss_fn=keyed_loss,
            loss_kwargs=train_kwargs,
            val_kwargs=held_out_kwargs,
            validation=validation,
            optimizer_name=optimizer,
        )

    @jax.jit
    def step_batch(params, state, key, key_val):
        start = [params, state, key, key_val]
        end, [chunk_losses, chunk_vals] = jax.lax.scan(scan_body, start, jnp.arange(chunk_len))
        new_params, new_state, new_key, new_key_val = end
        return BatchResult(
            params=new_params,
            state=new_state,
            key=new_key,
            key_val=new_key_val,
            losses=chunk_losses,
            val_losses=chunk_vals,
        )

    while True:
        batch = step_batch(params, state, key, key_val)
        # Thread the chunk's final state into the next chunk.
        params, state = batch.params, batch.state
        key, key_val = batch.key, batch.key_val
        yield batch
[docs]
def train(
    optimizer: str,
    loss: Callable,
    stepsize: float,
    n_iters: int,
    loss_kwargs: dict[str, Any],
    options: TrainingOptions | None = None,
) -> TrainingResult:
    # pylint: disable=too-many-arguments
    """
    Main training function.

    Manages the loop, accumulation of history, and convergence checks. Training runs in
    chunks of ``options.unroll_steps`` steps and stops early once the monitored loss has
    converged. The monitored loss is the validation loss whenever ``options.val_kwargs``
    is not ``None`` and the training loss otherwise.

    Because steps are executed in whole chunks, the returned parameters may reflect up to
    ``unroll_steps - 1`` steps beyond ``n_iters``; the returned loss histories are cut to
    at most ``n_iters`` entries.

    Args:
        optimizer (str): Name of the optimizer to use. Options are "GradientDescent", "Adam", or "BFGS".
        loss (Callable): The loss function.
        stepsize (float): The learning rate.
        n_iters (int): Total number of training iterations. If it is not positive, no
            training chunk is executed and the initial parameters are returned.
        loss_kwargs (dict[str, Any]): Arguments to pass to the loss function. Must contain
            the initial parameters under ``"params"``.
        options (TrainingOptions | None): Configuration options for training.
            See :class:`TrainingOptions` for further details.

    Returns:
        TrainingResult: The results of the training process, including final parameters and loss history.
        See :class:`TrainingResult` for further details.
    """
    options = options or TrainingOptions()
    unroll_steps = max(1, options.unroll_steps)
    total_batches = (n_iters + unroll_steps - 1) // unroll_steps

    # Use the same switch as ``training_iterator`` (``is not None``) so that an empty
    # ``val_kwargs`` dict, which does enable validation there, is monitored here as well.
    use_validation = options.val_kwargs is not None

    # The convergence test needs two full windows of history. Work out how many chunks
    # cover that many steps (ceiling division) instead of assuming a fixed chunk count,
    # which is too short whenever ``unroll_steps`` is small relative to the interval.
    history_needed = 2 * options.convergence_interval
    chunks_needed = max(1, -(-history_needed // unroll_steps))

    start_time = time.time()
    loss_acc = []
    val_loss_acc = []
    converged = False
    final_params = loss_kwargs["params"]

    iterator = training_iterator(
        optimizer=optimizer, loss=loss, stepsize=stepsize, loss_kwargs=loss_kwargs, options=options
    )

    with tqdm(total=n_iters, desc="Training Progress") as pbar:
        # ``zip`` consults ``range`` first, so once the budget is exhausted the generator
        # is not advanced again and no surplus chunk is computed.
        for i, batch_result in zip(range(total_batches), iterator):
            final_params = batch_result.params
            loss_acc.append(batch_result.losses)
            val_loss_acc.append(batch_result.val_losses)

            curr_loss = batch_result.losses[-1]
            pbar.set_postfix(
                {"loss": f"{curr_loss:.6f}", "elapsed": f"{time.time() - start_time:.1f}s"}
            )
            pbar.update(unroll_steps)

            current_step = (i + 1) * unroll_steps

            # Check based on validation loss if available, else training loss
            metric_acc = val_loss_acc if use_validation else loss_acc

            if current_step > history_needed:
                recent_history = jnp.concatenate(metric_acc[-chunks_needed:])
                if len(recent_history) >= history_needed and _check_convergence(
                    recent_history, options.convergence_interval
                ):
                    print(f"Training converged after {current_step} steps")
                    converged = True
                    break

    if not converged:
        print(f"Training has not converged after {n_iters} steps")

    all_losses = jnp.concatenate(loss_acc) if loss_acc else jnp.array([])
    all_val_losses = jnp.concatenate(val_loss_acc) if val_loss_acc else jnp.array([])

    # The last chunk may overshoot ``n_iters``; drop the surplus entries.
    if len(all_losses) > n_iters:
        all_losses = all_losses[:n_iters]
        all_val_losses = all_val_losses[:n_iters]

    return TrainingResult(
        final_params=final_params,
        losses=all_losses,
        val_losses=all_val_losses,
        run_time=time.time() - start_time,
    )
_modules/pennylane/labs/phox/training
Download Python script
Download Notebook
View on GitHub