Skip to content

Add support for simulation over classical distributions #1682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,42 @@ def _binst_to_cxns(
return pred_cxns, succ_cxns


def _get_soquet(
binst: 'BloqInstance',
reg_name: str,
right: bool = False,
idx: Tuple[int, ...] = (),
*,
binst_graph: nx.DiGraph,
) -> 'Soquet':
"""Retrieve a soquet given identifying information.

We can uniquely address a Soquet by the arguments to this function.

Args:
binst: The bloq instance associated with the desired soquet.
reg_name: The name of the register associated with the desired soquet.
right: If False, get the input, left soquet. Otherwise: the right, output soquet
idx: The index of the soquet within a multidimensional register, or the empty
tuple for basic registers.
"""
preds, succs = _binst_to_cxns(binst, binst_graph=binst_graph)
if right:
for suc in succs:
me = suc.left
if me.reg.name == reg_name and me.idx == idx:
return me
else:
for pred in preds:
me = pred.right
if me.reg.name == reg_name and me.idx == idx:
return me

raise ValueError(
f"Could not find the requested soquet with {binst=}, {reg_name=}, {right=}, {idx=}"
)


def _cxns_to_soq_dict(
regs: Iterable[Register],
cxns: Iterable[Connection],
Expand Down
26 changes: 25 additions & 1 deletion qualtran/_infra/composite_bloq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
Soquet,
SoquetT,
)
from qualtran._infra.composite_bloq import _create_binst_graph, _get_dangling_soquets
from qualtran._infra.composite_bloq import _create_binst_graph, _get_dangling_soquets, _get_soquet
from qualtran._infra.data_types import BQUInt, QAny, QBit, QFxp, QUInt
from qualtran.bloqs.basic_gates import CNOT, IntEffect, ZeroEffect
from qualtran.bloqs.bookkeeping import Join
Expand Down Expand Up @@ -619,6 +619,30 @@ def test_decompose_symbolic_register_shape_raises():
bloq.decompose_bloq()


class LeftRightSoquets(Bloq):
"""Bloq that outputs a random bit for testing."""

@property
def signature(self) -> 'Signature':
return Signature(
[Register('in', QBit(), side=Side.LEFT), Register('out', QBit(), side=Side.RIGHT)]
)


def test_get_soquet():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

bloq = LeftRightSoquets()
cbloq = bloq.as_composite_bloq()
binst = BloqInstance(bloq, 0)
binst_graph = cbloq._binst_graph # pylint: disable=protected-access

soquet = _get_soquet(binst=binst, reg_name='out', right=True, binst_graph=binst_graph)
assert soquet.reg.name == 'out'
soquet = _get_soquet(binst=binst, reg_name='in', right=False, binst_graph=binst_graph)
assert soquet.reg.name == 'in'
with pytest.raises(ValueError, match='Could not find the requested soquet'):
_ = _get_soquet(binst=binst, reg_name='in', right=True, binst_graph=binst_graph)


@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('composite_bloq')
169 changes: 160 additions & 9 deletions qualtran/simulation/classical_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Functionality for the `Bloq.call_classically(...)` protocol."""
import abc
import itertools
from typing import (
Any,
Expand All @@ -28,6 +29,7 @@
Union,
)

import attrs
import networkx as nx
import numpy as np
import sympy
Expand All @@ -43,13 +45,13 @@
Signature,
Soquet,
)
from qualtran._infra.composite_bloq import _binst_to_cxns
from qualtran._infra.composite_bloq import _binst_to_cxns, _get_soquet

if TYPE_CHECKING:
from qualtran import CompositeBloq, QCDType

ClassicalValT = Union[int, np.integer, NDArray[np.integer]]
ClassicalValRetT = Union[int, np.integer, NDArray[np.integer]]
ClassicalValRetT = Union[int, np.integer, NDArray[np.integer], 'ClassicalValDistribution']


def _numpy_dtype_from_qlt_dtype(dtype: 'QCDType') -> Type:
Expand Down Expand Up @@ -106,6 +108,88 @@ def _get_in_vals(
return arg


@attrs.frozen(hash=False)
class ClassicalValDistribution:
"""This class represents a distribution of classical values.

Use this if the bloq has performed a measurement or other projection
that has resulted in a mixed state of purely classical values.

Args:
a: An array of choices, or `np.arange` if an integer is given.
This is the `a` parameter to `np.random.Generator.choice()`.
p: An array of probabilities. If not supplied, the uniform distribution is assumed.
This is the `p` parameter to `np.random.Generator.choice()`.
"""

a: Union[int, np.typing.ArrayLike]
p: Optional[np.typing.ArrayLike] = None


class _ClassicalValHandler(metaclass=abc.ABCMeta):
"""An internal class for returning a random classical value.

Implmentors should write the get() function which returns a random
choice of values."""

@abc.abstractmethod
def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any: ...


class _RandomClassicalValHandler(_ClassicalValHandler):
"""Returns a random classical value using a random number generator."""

def __init__(self, rng: 'np.random.Generator'):
self._gen = rng

def get(self, binst, distribution: ClassicalValDistribution):
return self._gen.choice(distribution.a, p=distribution.p) # type:ignore[arg-type]


class _FixedClassicalValHandler(_ClassicalValHandler):
"""Returns a random classical value using a fixed value per bloq instance.

Useful for deterministic testing.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Args:
binst_i_to_val: mapping from BloqInstance.i instance indices to the fixed classical value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


Args:
binst_i_to_val: mapping from BloqInstance.i instance indices
to the fixed classical value.
"""

def __init__(self, binst_i_to_val: Dict[int, Any]):
self._binst_i_to_val = binst_i_to_val

def get(self, binst, distribution: ClassicalValDistribution):
return self._binst_i_to_val[binst.i]


class _BannedClassicalValHandler(_ClassicalValHandler):
"""Used when random classical value is not able to be performed."""

def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any:
raise ValueError(
f"{binst} has non-deterministic classical action."
"Cannot simulate with classical values."
)


@attrs.frozen
class MeasurementPhase:
"""Sentinel value for phases based on measurement outcomes:

This can be returned from `Bloq.basis_state_phase`
if a phase should be applied based on a measurement outcome.
This can be used in special circumstances to verify measurement-based uncomputation (MBUC).

Args:
reg_name: Name of the register
idx: Index of the register wire(s).
"""

reg_name: str
idx: Tuple[int, ...] = ()


class ClassicalSimState:
"""A mutable class for classically simulating composite bloqs.

Expand All @@ -122,6 +206,8 @@ class ClassicalSimState:
binst graph.
vals: A mapping of input register name to classical value to serve as inputs to the
procedure.
random_handler: The classical random number handler to use for use in
measurement-based outcomes (e.g. MBUC).

Attributes:
soq_assign: An assignment of soquets to classical values. We store the classical state
Expand All @@ -138,10 +224,12 @@ def __init__(
signature: 'Signature',
binst_graph: nx.DiGraph,
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]],
random_handler: '_ClassicalValHandler' = _BannedClassicalValHandler(),
):
self._signature = signature
self._binst_graph = binst_graph
self._binst_iter = nx.topological_sort(self._binst_graph)
self._random_handler = random_handler

# Keep track of each soquet's bit array. Initialize with LeftDangle
self.soq_assign: Dict[Soquet, ClassicalValT] = {}
Expand Down Expand Up @@ -206,6 +294,9 @@ def _update_assign_from_vals(

else:
# `val` is one value.
if isinstance(val, ClassicalValDistribution):
val = self._random_handler.get(binst, val)

reg.dtype.assert_valid_classical_val(val, debug_str)
soq = Soquet(binst, reg)
self.soq_assign[soq] = val
Expand Down Expand Up @@ -321,6 +412,8 @@ class directly for more fine-grained control.
soq_assign: An assignment of soquets to classical values.
last_binst: A record of the last bloq instance we processed during simulation.
phase: The current phase of the simulation state.
random_handler: The classical random number handler to use for use in
measurement-based outcomes (e.g. MBUC).
"""

def __init__(
Expand All @@ -330,26 +423,48 @@ def __init__(
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]],
*,
phase: complex = 1.0,
random_handler: '_ClassicalValHandler',
):
super().__init__(signature=signature, binst_graph=binst_graph, vals=vals)
super().__init__(
signature=signature, binst_graph=binst_graph, vals=vals, random_handler=random_handler
)
_assert_valid_phase(phase)
self.phase = phase

@classmethod
def from_cbloq(
cls, cbloq: 'CompositeBloq', vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]]
cls,
cbloq: 'CompositeBloq',
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]],
rng: Optional['np.random.Generator'] = None,
fixed_random_vals: Optional[Dict[int, Any]] = None,
) -> 'PhasedClassicalSimState':
"""Initiate a classical simulation from a CompositeBloq.

Args:
cbloq: The composite bloq
vals: A mapping of input register name to classical value to serve as inputs to the
procedure.
rng: A random number generator to use for classical random values, such a np.random.
fixed_random_vals: A dictionary of bloq instances to values to perform fixed calculation
for classical values.

Returns:
A new classical sim state.
"""
return cls(signature=cbloq.signature, binst_graph=cbloq._binst_graph, vals=vals)
rnd_handler: _ClassicalValHandler
if rng is not None:
rnd_handler = _RandomClassicalValHandler(rng=rng)
elif fixed_random_vals is not None:
rnd_handler = _FixedClassicalValHandler(binst_i_to_val=fixed_random_vals)
else:
rnd_handler = _BannedClassicalValHandler()
return cls(
signature=cbloq.signature,
binst_graph=cbloq._binst_graph,
vals=vals,
random_handler=rnd_handler,
)

def _binst_basis_state_phase(self, binst, in_vals):
"""Call `basis_state_phase` on a given bloq instance.
Expand All @@ -359,7 +474,26 @@ def _binst_basis_state_phase(self, binst, in_vals):
"""
bloq = binst.bloq
bloq_phase = bloq.basis_state_phase(**in_vals)
if bloq_phase is not None:
if isinstance(bloq_phase, MeasurementPhase):
# In this special case, there is a coupling between the classical result and the
# phase result (because the classical result is stochastic). We look up the measurement
# result and apply a phase if it is `1`.
meas_result = self.soq_assign[
_get_soquet(
binst=binst,
reg_name=bloq_phase.reg_name,
right=True,
idx=bloq_phase.idx,
binst_graph=self._binst_graph,
)
]
if meas_result == 1:
# Measurement result of 1, phase of -1
self.phase *= -1.0
else:
# Measurement result of 0, phase of +1
pass
elif bloq_phase is not None:
_assert_valid_phase(bloq_phase)
self.phase *= bloq_phase
else:
Expand All @@ -371,6 +505,9 @@ def call_cbloq_classically(
signature: Signature,
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]],
binst_graph: nx.DiGraph,
random_handler: '_ClassicalValHandler' = _RandomClassicalValHandler(
rng=np.random.default_rng()
),
) -> Tuple[Dict[str, ClassicalValT], Dict[Soquet, ClassicalValT]]:
"""Propagate `on_classical_vals` calls through a composite bloq's contents.

Expand All @@ -381,14 +518,16 @@ def call_cbloq_classically(
signature: The cbloq's signature for validating inputs
vals: Mapping from register name to classical values
binst_graph: The cbloq's binst graph.
random_handler: The classical random number handler to use for use in
measurement-based outcomes (e.g. MBUC).

Returns:
final_vals: A mapping from register name to output classical values
soq_assign: An assignment from each soquet to its classical value. Soquets
corresponding to thru registers will be mapped to the *output* classical
value.
"""
sim = ClassicalSimState(signature, binst_graph, vals)
sim = ClassicalSimState(signature, binst_graph, vals, random_handler)
final_vals = sim.simulate()
return final_vals, sim.soq_assign

Expand All @@ -398,7 +537,12 @@ def _assert_valid_phase(p: complex, atol: float = 1e-8):
raise ValueError(f"Phases must have unit modulus. Found {p}.")


def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalValT']):
def do_phased_classical_simulation(
bloq: 'Bloq',
vals: Mapping[str, 'ClassicalValT'],
rng: Optional['np.random.Generator'] = None,
fixed_random_vals: Optional[Dict[int, Any]] = None,
):
"""Do a phased classical simulation of the bloq.

This provides a simple interface to `PhasedClassicalSimState`. Advanced users
Expand All @@ -408,13 +552,20 @@ def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalVa
bloq: The bloq to simulate
vals: A mapping from input register name to initial classical values. The initial phase is
assumed to be 1.0.
rng: A numpy random generator (e.g. from `np.random.default_rng()`). This function
will use this generator to supply random values from certain phased-classical operations
like `MeasX`. If not supplied, classical measurements will use a random value.
fixed_random_vals: A dictionary of instance to values to perform fixed calculation
for classical values.

Returns:
final_vals: A mapping of output register name to final classical values.
phase: The final phase.
"""
cbloq = bloq.as_composite_bloq()
sim = PhasedClassicalSimState.from_cbloq(cbloq, vals=vals)
sim = PhasedClassicalSimState.from_cbloq(
cbloq, vals=vals, rng=rng, fixed_random_vals=fixed_random_vals
)
final_vals = sim.simulate()
phase = sim.phase
return final_vals, phase
Expand Down
Loading