diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index 9832b9ffa..1d3f91e0b 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -512,6 +512,42 @@ def _binst_to_cxns( return pred_cxns, succ_cxns +def _get_soquet( + binst: 'BloqInstance', + reg_name: str, + right: bool = False, + idx: Tuple[int, ...] = (), + *, + binst_graph: nx.DiGraph, +) -> 'Soquet': + """Retrieve a soquet given identifying information. + + We can uniquely address a Soquet by the arguments to this function. + + Args: + binst: The bloq instance associated with the desired soquet. + reg_name: The name of the register associated with the desired soquet. + right: If False, get the input, left soquet. Otherwise: the right, output soquet + idx: The index of the soquet within a multidimensional register, or the empty + tuple for basic registers. + """ + preds, succs = _binst_to_cxns(binst, binst_graph=binst_graph) + if right: + for suc in succs: + me = suc.left + if me.reg.name == reg_name and me.idx == idx: + return me + else: + for pred in preds: + me = pred.right + if me.reg.name == reg_name and me.idx == idx: + return me + + raise ValueError( + f"Could not find the requested soquet with {binst=}, {reg_name=}, {right=}, {idx=}" + ) + + def _cxns_to_soq_dict( regs: Iterable[Register], cxns: Iterable[Connection], diff --git a/qualtran/_infra/composite_bloq_test.py b/qualtran/_infra/composite_bloq_test.py index 114b55083..83dfa234b 100644 --- a/qualtran/_infra/composite_bloq_test.py +++ b/qualtran/_infra/composite_bloq_test.py @@ -39,7 +39,7 @@ Soquet, SoquetT, ) -from qualtran._infra.composite_bloq import _create_binst_graph, _get_dangling_soquets +from qualtran._infra.composite_bloq import _create_binst_graph, _get_dangling_soquets, _get_soquet from qualtran._infra.data_types import BQUInt, QAny, QBit, QFxp, QUInt from qualtran.bloqs.basic_gates import CNOT, IntEffect, ZeroEffect from qualtran.bloqs.bookkeeping import Join @@ -619,6 +619,30 @@ def test_decompose_symbolic_register_shape_raises(): bloq.decompose_bloq() +class LeftRightSoquets(Bloq): + """Bloq that outputs a random bit for testing.""" + + @property + def signature(self) -> 'Signature': + return Signature( + [Register('in', QBit(), side=Side.LEFT), Register('out', QBit(), side=Side.RIGHT)] + ) + + +def test_get_soquet(): + bloq = LeftRightSoquets() + cbloq = bloq.as_composite_bloq() + binst = BloqInstance(bloq, 0) + binst_graph = cbloq._binst_graph # pylint: disable=protected-access + + soquet = _get_soquet(binst=binst, reg_name='out', right=True, binst_graph=binst_graph) + assert soquet.reg.name == 'out' + soquet = _get_soquet(binst=binst, reg_name='in', right=False, binst_graph=binst_graph) + assert soquet.reg.name == 'in' + with pytest.raises(ValueError, match='Could not find the requested soquet'): + _ = _get_soquet(binst=binst, reg_name='in', right=True, binst_graph=binst_graph) + + @pytest.mark.notebook def test_notebook(): qlt_testing.execute_notebook('composite_bloq') diff --git a/qualtran/simulation/classical_sim.py b/qualtran/simulation/classical_sim.py index 927c93e00..00f40f7cc 100644 --- a/qualtran/simulation/classical_sim.py +++ b/qualtran/simulation/classical_sim.py @@ -13,6 +13,7 @@ # limitations under the License. """Functionality for the `Bloq.call_classically(...)` protocol.""" +import abc import itertools from typing import ( Any, @@ -28,6 +29,7 @@ Union, ) +import attrs import networkx as nx import numpy as np import sympy @@ -43,13 +45,13 @@ Signature, Soquet, ) -from qualtran._infra.composite_bloq import _binst_to_cxns +from qualtran._infra.composite_bloq import _binst_to_cxns, _get_soquet if TYPE_CHECKING: from qualtran import CompositeBloq, QCDType ClassicalValT = Union[int, np.integer, NDArray[np.integer]] -ClassicalValRetT = Union[int, np.integer, NDArray[np.integer]] +ClassicalValRetT = Union[int, np.integer, NDArray[np.integer], 'ClassicalValDistribution'] def _numpy_dtype_from_qlt_dtype(dtype: 'QCDType') -> Type: @@ -106,6 +108,88 @@ def _get_in_vals( return arg +@attrs.frozen(hash=False) +class ClassicalValDistribution: + """This class represents a distribution of classical values. + + Use this if the bloq has performed a measurement or other projection + that has resulted in a mixed state of purely classical values. + + Args: + a: An array of choices, or `np.arange` if an integer is given. + This is the `a` parameter to `np.random.Generator.choice()`. + p: An array of probabilities. If not supplied, the uniform distribution is assumed. + This is the `p` parameter to `np.random.Generator.choice()`. + """ + + a: Union[int, np.typing.ArrayLike] + p: Optional[np.typing.ArrayLike] = None + + +class _ClassicalValHandler(metaclass=abc.ABCMeta): + """An internal class for returning a random classical value. + + Implmentors should write the get() function which returns a random + choice of values.""" + + @abc.abstractmethod + def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any: ... + + +class _RandomClassicalValHandler(_ClassicalValHandler): + """Returns a random classical value using a random number generator.""" + + def __init__(self, rng: 'np.random.Generator'): + self._gen = rng + + def get(self, binst, distribution: ClassicalValDistribution): + return self._gen.choice(distribution.a, p=distribution.p) # type:ignore[arg-type] + + +class _FixedClassicalValHandler(_ClassicalValHandler): + """Returns a random classical value using a fixed value per bloq instance. + + Useful for deterministic testing. + + Args: + binst_i_to_val: mapping from BloqInstance.i instance indices + to the fixed classical value. + """ + + def __init__(self, binst_i_to_val: Dict[int, Any]): + self._binst_i_to_val = binst_i_to_val + + def get(self, binst, distribution: ClassicalValDistribution): + return self._binst_i_to_val[binst.i] + + +class _BannedClassicalValHandler(_ClassicalValHandler): + """Used when random classical value is not able to be performed.""" + + def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any: + raise ValueError( + f"{binst} has non-deterministic classical action." + "Cannot simulate with classical values." + ) + + +@attrs.frozen +class MeasurementPhase: + """Sentinel value for phases based on measurement outcomes: + + This can be returned from `Bloq.basis_state_phase` + if a phase should be applied based on a measurement outcome. + This can be used in special circumstances to verify measurement-based uncomputation (MBUC). + + Args: + reg_name: Name of the register + idx: Index of the register wire(s). + """ + + reg_name: str + idx: Tuple[int, ...] = () + + class ClassicalSimState: """A mutable class for classically simulating composite bloqs. @@ -122,6 +206,8 @@ class ClassicalSimState: binst graph. vals: A mapping of input register name to classical value to serve as inputs to the procedure. + random_handler: The classical random number handler to use for use in + measurement-based outcomes (e.g. MBUC). Attributes: soq_assign: An assignment of soquets to classical values. We store the classical state @@ -138,10 +224,12 @@ def __init__( signature: 'Signature', binst_graph: nx.DiGraph, vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], + random_handler: '_ClassicalValHandler' = _BannedClassicalValHandler(), ): self._signature = signature self._binst_graph = binst_graph self._binst_iter = nx.topological_sort(self._binst_graph) + self._random_handler = random_handler # Keep track of each soquet's bit array. Initialize with LeftDangle self.soq_assign: Dict[Soquet, ClassicalValT] = {} @@ -206,6 +294,9 @@ def _update_assign_from_vals( else: # `val` is one value. + if isinstance(val, ClassicalValDistribution): + val = self._random_handler.get(binst, val) + reg.dtype.assert_valid_classical_val(val, debug_str) soq = Soquet(binst, reg) self.soq_assign[soq] = val @@ -321,6 +412,8 @@ class directly for more fine-grained control. soq_assign: An assignment of soquets to classical values. last_binst: A record of the last bloq instance we processed during simulation. phase: The current phase of the simulation state. + random_handler: The classical random number handler to use for use in + measurement-based outcomes (e.g. MBUC). """ def __init__( @@ -330,14 +423,21 @@ def __init__( vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], *, phase: complex = 1.0, + random_handler: '_ClassicalValHandler', ): - super().__init__(signature=signature, binst_graph=binst_graph, vals=vals) + super().__init__( + signature=signature, binst_graph=binst_graph, vals=vals, random_handler=random_handler + ) _assert_valid_phase(phase) self.phase = phase @classmethod def from_cbloq( - cls, cbloq: 'CompositeBloq', vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]] + cls, + cbloq: 'CompositeBloq', + vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], + rng: Optional['np.random.Generator'] = None, + fixed_random_vals: Optional[Dict[int, Any]] = None, ) -> 'PhasedClassicalSimState': """Initiate a classical simulation from a CompositeBloq. @@ -345,11 +445,26 @@ def from_cbloq( cbloq: The composite bloq vals: A mapping of input register name to classical value to serve as inputs to the procedure. + rng: A random number generator to use for classical random values, such a np.random. + fixed_random_vals: A dictionary of bloq instances to values to perform fixed calculation + for classical values. Returns: A new classical sim state. """ - return cls(signature=cbloq.signature, binst_graph=cbloq._binst_graph, vals=vals) + rnd_handler: _ClassicalValHandler + if rng is not None: + rnd_handler = _RandomClassicalValHandler(rng=rng) + elif fixed_random_vals is not None: + rnd_handler = _FixedClassicalValHandler(binst_i_to_val=fixed_random_vals) + else: + rnd_handler = _BannedClassicalValHandler() + return cls( + signature=cbloq.signature, + binst_graph=cbloq._binst_graph, + vals=vals, + random_handler=rnd_handler, + ) def _binst_basis_state_phase(self, binst, in_vals): """Call `basis_state_phase` on a given bloq instance. @@ -359,7 +474,26 @@ def _binst_basis_state_phase(self, binst, in_vals): """ bloq = binst.bloq bloq_phase = bloq.basis_state_phase(**in_vals) - if bloq_phase is not None: + if isinstance(bloq_phase, MeasurementPhase): + # In this special case, there is a coupling between the classical result and the + # phase result (because the classical result is stochastic). We look up the measurement + # result and apply a phase if it is `1`. + meas_result = self.soq_assign[ + _get_soquet( + binst=binst, + reg_name=bloq_phase.reg_name, + right=True, + idx=bloq_phase.idx, + binst_graph=self._binst_graph, + ) + ] + if meas_result == 1: + # Measurement result of 1, phase of -1 + self.phase *= -1.0 + else: + # Measurement result of 0, phase of +1 + pass + elif bloq_phase is not None: _assert_valid_phase(bloq_phase) self.phase *= bloq_phase else: @@ -371,6 +505,9 @@ def call_cbloq_classically( signature: Signature, vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], binst_graph: nx.DiGraph, + random_handler: '_ClassicalValHandler' = _RandomClassicalValHandler( + rng=np.random.default_rng() + ), ) -> Tuple[Dict[str, ClassicalValT], Dict[Soquet, ClassicalValT]]: """Propagate `on_classical_vals` calls through a composite bloq's contents. @@ -381,6 +518,8 @@ def call_cbloq_classically( signature: The cbloq's signature for validating inputs vals: Mapping from register name to classical values binst_graph: The cbloq's binst graph. + random_handler: The classical random number handler to use for use in + measurement-based outcomes (e.g. MBUC). Returns: final_vals: A mapping from register name to output classical values @@ -388,7 +527,7 @@ def call_cbloq_classically( corresponding to thru registers will be mapped to the *output* classical value. """ - sim = ClassicalSimState(signature, binst_graph, vals) + sim = ClassicalSimState(signature, binst_graph, vals, random_handler) final_vals = sim.simulate() return final_vals, sim.soq_assign @@ -398,7 +537,12 @@ def _assert_valid_phase(p: complex, atol: float = 1e-8): raise ValueError(f"Phases must have unit modulus. Found {p}.") -def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalValT']): +def do_phased_classical_simulation( + bloq: 'Bloq', + vals: Mapping[str, 'ClassicalValT'], + rng: Optional['np.random.Generator'] = None, + fixed_random_vals: Optional[Dict[int, Any]] = None, +): """Do a phased classical simulation of the bloq. This provides a simple interface to `PhasedClassicalSimState`. Advanced users @@ -408,13 +552,20 @@ def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalVa bloq: The bloq to simulate vals: A mapping from input register name to initial classical values. The initial phase is assumed to be 1.0. + rng: A numpy random generator (e.g. from `np.random.default_rng()`). This function + will use this generator to supply random values from certain phased-classical operations + like `MeasX`. If not supplied, classical measurements will use a random value. + fixed_random_vals: A dictionary of instance to values to perform fixed calculation + for classical values. Returns: final_vals: A mapping of output register name to final classical values. phase: The final phase. """ cbloq = bloq.as_composite_bloq() - sim = PhasedClassicalSimState.from_cbloq(cbloq, vals=vals) + sim = PhasedClassicalSimState.from_cbloq( + cbloq, vals=vals, rng=rng, fixed_random_vals=fixed_random_vals + ) final_vals = sim.simulate() phase = sim.phase return final_vals, phase diff --git a/qualtran/simulation/classical_sim_test.py b/qualtran/simulation/classical_sim_test.py index 582f5057a..751f1985e 100644 --- a/qualtran/simulation/classical_sim_test.py +++ b/qualtran/simulation/classical_sim_test.py @@ -13,7 +13,7 @@ # limitations under the License. import itertools -from typing import Dict +from typing import Dict, Union import networkx as nx import numpy as np @@ -25,6 +25,7 @@ Bloq, BloqBuilder, BQUInt, + CBit, LeftDangle, QAny, QBit, @@ -40,10 +41,15 @@ ) from qualtran.bloqs.basic_gates import CNOT from qualtran.simulation.classical_sim import ( + _BannedClassicalValHandler, + _FixedClassicalValHandler, add_ints, call_cbloq_classically, ClassicalSimState, + ClassicalValDistribution, + ClassicalValRetT, do_phased_classical_simulation, + MeasurementPhase, ) from qualtran.testing import execute_notebook @@ -260,3 +266,75 @@ def test_multidimensional_classical_sim_for_gqf(): y = bloq.call_classically(x=x)[0] assert isinstance(y, dtype.gf_type) np.testing.assert_equal(y, x) + + +@frozen +class ClassicalDistributionTest(Bloq): + """Bloq that outputs a random bit for testing.""" + + @property + def signature(self) -> 'Signature': + return Signature( + [Register('q', QBit(), side=Side.LEFT), Register('c', CBit(), side=Side.RIGHT)] + ) + + def on_classical_vals(self, *, q: int) -> Dict[str, ClassicalValRetT]: + return {'c': ClassicalValDistribution(2)} + + +@frozen +class ClassicalDistributionWithPhaseTest(ClassicalDistributionTest): + """Bloq that outputs a random bit with phase for testing.""" + + def basis_state_phase(self, q: int) -> Union[complex, MeasurementPhase]: + if q == 0: + return 1 + return MeasurementPhase(reg_name='c') + + +def test_classical_distribution() -> None: + bloq = ClassicalDistributionTest() + results = [bloq.call_classically(q=0)[0] for _ in range(100)] + assert all(c in {0, 1} for c in results) + assert any(c == 0 for c in results) + assert any(c == 1 for c in results) + + +def test_fixed_distribution() -> None: + cbloq = ClassicalDistributionTest().as_composite_bloq() + fixed_rng = _FixedClassicalValHandler(binst_i_to_val={0: 1}) + binst_graph = cbloq._binst_graph # pylint: disable=protected-access + results = [ + call_cbloq_classically(cbloq.signature, dict(q=0), binst_graph, random_handler=fixed_rng)[ + 0 + ]['c'] + for _ in range(100) + ] + assert all(c == 1 for c in results) + + +def test_banned_distribution() -> None: + cbloq = ClassicalDistributionTest().as_composite_bloq() + banned_rng = _BannedClassicalValHandler() + binst_graph = cbloq._binst_graph # pylint: disable=protected-access + with pytest.raises(ValueError, match='non-deterministic classical action'): + _ = call_cbloq_classically( + cbloq.signature, dict(q=0), binst_graph, random_handler=banned_rng + ) + + +def test_phased_classical_distribution(): + bloq = ClassicalDistributionWithPhaseTest() + _, phase = do_phased_classical_simulation(bloq, dict(q=0), rng=np.random.default_rng()) + assert phase == 1 + for _ in range(50): + final_vals, phase = do_phased_classical_simulation( + bloq, dict(q=1), rng=np.random.default_rng() + ) + assert phase == (-1 if final_vals['c'] else 1) + + final_values, phase = do_phased_classical_simulation( + bloq, dict(q=0), fixed_random_vals={0: 1, 1: 1} + ) + assert final_values['c'] == 1 + assert phase == 1