-
Notifications
You must be signed in to change notification settings - Fork 74
Add support for simulation over classical distributions #1682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+300
−11
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
# limitations under the License. | ||
|
||
"""Functionality for the `Bloq.call_classically(...)` protocol.""" | ||
import abc | ||
import itertools | ||
from typing import ( | ||
Any, | ||
|
@@ -28,6 +29,7 @@ | |
Union, | ||
) | ||
|
||
import attrs | ||
import networkx as nx | ||
import numpy as np | ||
import sympy | ||
|
@@ -43,13 +45,13 @@ | |
Signature, | ||
Soquet, | ||
) | ||
from qualtran._infra.composite_bloq import _binst_to_cxns | ||
from qualtran._infra.composite_bloq import _binst_to_cxns, _get_soquet | ||
|
||
if TYPE_CHECKING: | ||
from qualtran import CompositeBloq, QCDType | ||
|
||
ClassicalValT = Union[int, np.integer, NDArray[np.integer]] | ||
ClassicalValRetT = Union[int, np.integer, NDArray[np.integer]] | ||
ClassicalValRetT = Union[int, np.integer, NDArray[np.integer], 'ClassicalValDistribution'] | ||
|
||
|
||
def _numpy_dtype_from_qlt_dtype(dtype: 'QCDType') -> Type: | ||
|
@@ -106,6 +108,88 @@ def _get_in_vals( | |
return arg | ||
|
||
|
||
@attrs.frozen(hash=False) | ||
class ClassicalValDistribution: | ||
"""This class represents a distribution of classical values. | ||
|
||
Use this if the bloq has performed a measurement or other projection | ||
that has resulted in a mixed state of purely classical values. | ||
|
||
Args: | ||
a: An array of choices, or `np.arange` if an integer is given. | ||
This is the `a` parameter to `np.random.Generator.choice()`. | ||
p: An array of probabilities. If not supplied, the uniform distribution is assumed. | ||
This is the `p` parameter to `np.random.Generator.choice()`. | ||
""" | ||
|
||
a: Union[int, np.typing.ArrayLike] | ||
p: Optional[np.typing.ArrayLike] = None | ||
|
||
|
||
class _ClassicalValHandler(metaclass=abc.ABCMeta): | ||
"""An internal class for returning a random classical value. | ||
|
||
Implmentors should write the get() function which returns a random | ||
choice of values.""" | ||
|
||
@abc.abstractmethod | ||
def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any: ... | ||
|
||
|
||
class _RandomClassicalValHandler(_ClassicalValHandler): | ||
"""Returns a random classical value using a random number generator.""" | ||
|
||
def __init__(self, rng: 'np.random.Generator'): | ||
self._gen = rng | ||
|
||
def get(self, binst, distribution: ClassicalValDistribution): | ||
return self._gen.choice(distribution.a, p=distribution.p) # type:ignore[arg-type] | ||
|
||
|
||
class _FixedClassicalValHandler(_ClassicalValHandler): | ||
"""Returns a random classical value using a fixed value per bloq instance. | ||
|
||
Useful for deterministic testing. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Args: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
Args: | ||
binst_i_to_val: mapping from BloqInstance.i instance indices | ||
to the fixed classical value. | ||
""" | ||
|
||
def __init__(self, binst_i_to_val: Dict[int, Any]): | ||
self._binst_i_to_val = binst_i_to_val | ||
|
||
def get(self, binst, distribution: ClassicalValDistribution): | ||
return self._binst_i_to_val[binst.i] | ||
|
||
|
||
class _BannedClassicalValHandler(_ClassicalValHandler): | ||
"""Used when random classical value is not able to be performed.""" | ||
|
||
def get(self, binst: 'BloqInstance', distribution: ClassicalValDistribution) -> Any: | ||
raise ValueError( | ||
f"{binst} has non-deterministic classical action." | ||
"Cannot simulate with classical values." | ||
) | ||
|
||
|
||
@attrs.frozen | ||
class MeasurementPhase: | ||
"""Sentinel value for phases based on measurement outcomes: | ||
|
||
This can be returned from `Bloq.basis_state_phase` | ||
if a phase should be applied based on a measurement outcome. | ||
This can be used in special circumstances to verify measurement-based uncomputation (MBUC). | ||
|
||
Args: | ||
reg_name: Name of the register | ||
idx: Index of the register wire(s). | ||
""" | ||
|
||
reg_name: str | ||
idx: Tuple[int, ...] = () | ||
|
||
|
||
class ClassicalSimState: | ||
"""A mutable class for classically simulating composite bloqs. | ||
|
||
|
@@ -122,6 +206,8 @@ class ClassicalSimState: | |
binst graph. | ||
vals: A mapping of input register name to classical value to serve as inputs to the | ||
procedure. | ||
random_handler: The classical random number handler to use for use in | ||
measurement-based outcomes (e.g. MBUC). | ||
|
||
Attributes: | ||
soq_assign: An assignment of soquets to classical values. We store the classical state | ||
|
@@ -138,10 +224,12 @@ def __init__( | |
signature: 'Signature', | ||
binst_graph: nx.DiGraph, | ||
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], | ||
random_handler: '_ClassicalValHandler' = _BannedClassicalValHandler(), | ||
): | ||
self._signature = signature | ||
self._binst_graph = binst_graph | ||
self._binst_iter = nx.topological_sort(self._binst_graph) | ||
self._random_handler = random_handler | ||
|
||
# Keep track of each soquet's bit array. Initialize with LeftDangle | ||
self.soq_assign: Dict[Soquet, ClassicalValT] = {} | ||
|
@@ -206,6 +294,9 @@ def _update_assign_from_vals( | |
|
||
else: | ||
# `val` is one value. | ||
if isinstance(val, ClassicalValDistribution): | ||
val = self._random_handler.get(binst, val) | ||
|
||
reg.dtype.assert_valid_classical_val(val, debug_str) | ||
soq = Soquet(binst, reg) | ||
self.soq_assign[soq] = val | ||
|
@@ -321,6 +412,8 @@ class directly for more fine-grained control. | |
soq_assign: An assignment of soquets to classical values. | ||
last_binst: A record of the last bloq instance we processed during simulation. | ||
phase: The current phase of the simulation state. | ||
random_handler: The classical random number handler to use for use in | ||
measurement-based outcomes (e.g. MBUC). | ||
""" | ||
|
||
def __init__( | ||
|
@@ -330,26 +423,48 @@ def __init__( | |
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], | ||
*, | ||
phase: complex = 1.0, | ||
random_handler: '_ClassicalValHandler', | ||
): | ||
super().__init__(signature=signature, binst_graph=binst_graph, vals=vals) | ||
super().__init__( | ||
signature=signature, binst_graph=binst_graph, vals=vals, random_handler=random_handler | ||
) | ||
_assert_valid_phase(phase) | ||
self.phase = phase | ||
|
||
@classmethod | ||
def from_cbloq( | ||
cls, cbloq: 'CompositeBloq', vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]] | ||
cls, | ||
cbloq: 'CompositeBloq', | ||
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], | ||
rng: Optional['np.random.Generator'] = None, | ||
fixed_random_vals: Optional[Dict[int, Any]] = None, | ||
) -> 'PhasedClassicalSimState': | ||
"""Initiate a classical simulation from a CompositeBloq. | ||
|
||
Args: | ||
cbloq: The composite bloq | ||
vals: A mapping of input register name to classical value to serve as inputs to the | ||
procedure. | ||
rng: A random number generator to use for classical random values, such a np.random. | ||
fixed_random_vals: A dictionary of bloq instances to values to perform fixed calculation | ||
for classical values. | ||
|
||
Returns: | ||
A new classical sim state. | ||
""" | ||
return cls(signature=cbloq.signature, binst_graph=cbloq._binst_graph, vals=vals) | ||
rnd_handler: _ClassicalValHandler | ||
if rng is not None: | ||
rnd_handler = _RandomClassicalValHandler(rng=rng) | ||
elif fixed_random_vals is not None: | ||
rnd_handler = _FixedClassicalValHandler(binst_i_to_val=fixed_random_vals) | ||
else: | ||
rnd_handler = _BannedClassicalValHandler() | ||
return cls( | ||
signature=cbloq.signature, | ||
binst_graph=cbloq._binst_graph, | ||
vals=vals, | ||
random_handler=rnd_handler, | ||
) | ||
|
||
def _binst_basis_state_phase(self, binst, in_vals): | ||
"""Call `basis_state_phase` on a given bloq instance. | ||
|
@@ -359,7 +474,26 @@ def _binst_basis_state_phase(self, binst, in_vals): | |
""" | ||
bloq = binst.bloq | ||
bloq_phase = bloq.basis_state_phase(**in_vals) | ||
if bloq_phase is not None: | ||
if isinstance(bloq_phase, MeasurementPhase): | ||
# In this special case, there is a coupling between the classical result and the | ||
# phase result (because the classical result is stochastic). We look up the measurement | ||
# result and apply a phase if it is `1`. | ||
meas_result = self.soq_assign[ | ||
_get_soquet( | ||
binst=binst, | ||
reg_name=bloq_phase.reg_name, | ||
right=True, | ||
idx=bloq_phase.idx, | ||
binst_graph=self._binst_graph, | ||
) | ||
] | ||
if meas_result == 1: | ||
# Measurement result of 1, phase of -1 | ||
self.phase *= -1.0 | ||
else: | ||
# Measurement result of 0, phase of +1 | ||
pass | ||
elif bloq_phase is not None: | ||
_assert_valid_phase(bloq_phase) | ||
self.phase *= bloq_phase | ||
else: | ||
|
@@ -371,6 +505,9 @@ def call_cbloq_classically( | |
signature: Signature, | ||
vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], | ||
binst_graph: nx.DiGraph, | ||
random_handler: '_ClassicalValHandler' = _RandomClassicalValHandler( | ||
rng=np.random.default_rng() | ||
), | ||
) -> Tuple[Dict[str, ClassicalValT], Dict[Soquet, ClassicalValT]]: | ||
"""Propagate `on_classical_vals` calls through a composite bloq's contents. | ||
|
||
|
@@ -381,14 +518,16 @@ def call_cbloq_classically( | |
signature: The cbloq's signature for validating inputs | ||
vals: Mapping from register name to classical values | ||
binst_graph: The cbloq's binst graph. | ||
random_handler: The classical random number handler to use for use in | ||
measurement-based outcomes (e.g. MBUC). | ||
|
||
Returns: | ||
final_vals: A mapping from register name to output classical values | ||
soq_assign: An assignment from each soquet to its classical value. Soquets | ||
corresponding to thru registers will be mapped to the *output* classical | ||
value. | ||
""" | ||
sim = ClassicalSimState(signature, binst_graph, vals) | ||
sim = ClassicalSimState(signature, binst_graph, vals, random_handler) | ||
final_vals = sim.simulate() | ||
return final_vals, sim.soq_assign | ||
|
||
|
@@ -398,7 +537,12 @@ def _assert_valid_phase(p: complex, atol: float = 1e-8): | |
raise ValueError(f"Phases must have unit modulus. Found {p}.") | ||
|
||
|
||
def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalValT']): | ||
def do_phased_classical_simulation( | ||
bloq: 'Bloq', | ||
vals: Mapping[str, 'ClassicalValT'], | ||
rng: Optional['np.random.Generator'] = None, | ||
fixed_random_vals: Optional[Dict[int, Any]] = None, | ||
): | ||
"""Do a phased classical simulation of the bloq. | ||
|
||
This provides a simple interface to `PhasedClassicalSimState`. Advanced users | ||
|
@@ -408,13 +552,20 @@ def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalVa | |
bloq: The bloq to simulate | ||
vals: A mapping from input register name to initial classical values. The initial phase is | ||
assumed to be 1.0. | ||
rng: A numpy random generator (e.g. from `np.random.default_rng()`). This function | ||
will use this generator to supply random values from certain phased-classical operations | ||
like `MeasX`. If not supplied, classical measurements will use a random value. | ||
fixed_random_vals: A dictionary of instance to values to perform fixed calculation | ||
for classical values. | ||
|
||
Returns: | ||
final_vals: A mapping of output register name to final classical values. | ||
phase: The final phase. | ||
""" | ||
cbloq = bloq.as_composite_bloq() | ||
sim = PhasedClassicalSimState.from_cbloq(cbloq, vals=vals) | ||
sim = PhasedClassicalSimState.from_cbloq( | ||
cbloq, vals=vals, rng=rng, fixed_random_vals=fixed_random_vals | ||
) | ||
final_vals = sim.simulate() | ||
phase = sim.phase | ||
return final_vals, phase | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you!