From 1ea89feaa9fa4f61bf627a3ecb2c7246b82265c9 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 06:05:03 -0700 Subject: [PATCH 01/22] add changes to apply_unitary and unitary to be able to take in numpy array input --- .../cirq/protocols/apply_mixture_protocol.py | 28 ------------------- .../cirq/protocols/apply_unitary_protocol.py | 23 +++++++++------ .../protocols/apply_unitary_protocol_test.py | 3 +- cirq-core/cirq/protocols/kraus_protocol.py | 2 +- cirq-core/cirq/protocols/mixture_protocol.py | 2 +- cirq-core/cirq/protocols/unitary_protocol.py | 4 +++ .../cirq/protocols/unitary_protocol_test.py | 4 +++ 7 files changed, 26 insertions(+), 40 deletions(-) diff --git a/cirq-core/cirq/protocols/apply_mixture_protocol.py b/cirq-core/cirq/protocols/apply_mixture_protocol.py index 2feb3cc659c..1f6b4f74853 100644 --- a/cirq-core/cirq/protocols/apply_mixture_protocol.py +++ b/cirq-core/cirq/protocols/apply_mixture_protocol.py @@ -332,32 +332,6 @@ def _apply_unitary_strat( return right_result -def _apply_unitary_from_matrix_strat( - val: np.ndarray, args: ApplyMixtureArgs, is_density_matrix: bool -) -> np.ndarray | None: - """Used to enact mixture tuples that are given as (probability, np.ndarray) - - If `val` does not support `apply_unitary` returns None. - """ - qid_shape = tuple(args.target_tensor.shape[i] for i in args.left_axes) - matrix_tensor = np.reshape(val.astype(args.target_tensor.dtype), qid_shape * 2) - linalg.targeted_left_multiply( - matrix_tensor, args.target_tensor, args.left_axes, out=args.auxiliary_buffer0 - ) - - if not is_density_matrix: - return args.auxiliary_buffer0 - # No need to transpose as we are acting on the tensor - # representation of matrix, so transpose is done for us. - linalg.targeted_left_multiply( - np.conjugate(matrix_tensor), - args.auxiliary_buffer0, - cast(tuple[int], args.right_axes), - out=args.target_tensor, - ) - return args.target_tensor - - def _apply_mixture_from_mixture_strat( val: Any, args: ApplyMixtureArgs, is_density_matrix: bool ) -> np.ndarray | None: @@ -373,8 +347,6 @@ def _apply_mixture_from_mixture_strat( for prob, op in prob_mix: np.copyto(dst=args.target_tensor, src=args.auxiliary_buffer1) right_result = _apply_unitary_strat(op, args, is_density_matrix) - if right_result is None: - right_result = _apply_unitary_from_matrix_strat(op, args, is_density_matrix) args.out_buffer += prob * right_result diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol.py b/cirq-core/cirq/protocols/apply_unitary_protocol.py index 9aeb37ad509..5791ff6f088 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol.py @@ -469,15 +469,20 @@ def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: App def _strat_apply_unitary_from_unitary( unitary_value: Any, args: ApplyUnitaryArgs ) -> np.ndarray | None: - # Check for magic method. - method = getattr(unitary_value, '_unitary_', None) - if method is None: - return NotImplemented - - # Attempt to get the unitary matrix. - matrix = method() - if matrix is NotImplemented or matrix is None: - return matrix + if isinstance(unitary_value, np.ndarray): + matrix = unitary_value + if not linalg.is_unitary(matrix): + return None + else: + # Check for magic method. + method = getattr(unitary_value, '_unitary_', None) + if method is None: + return NotImplemented + + # Attempt to get the unitary matrix. + matrix = method() + if matrix is NotImplemented or matrix is None: + return matrix return _apply_unitary_from_matrix(matrix, unitary_value, args) diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py index 8f2acc08278..94b3ee58733 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py @@ -56,12 +56,13 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray: args.target_tensor[one] *= -1 return args.target_tensor - fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented()] + fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented(), m * 2] passes = [ HasUnitary(), HasApplyReturnsNotImplementedButHasUnitary(), HasApplyOutputInBuffer(), HasApplyMutateInline(), + m, ] def make_input(): diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 09af067d9f6..00b49b6467a 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -149,7 +149,7 @@ def kraus( mixture_result = NotImplemented if mixture_getter is None else mixture_getter() if mixture_result is not NotImplemented and mixture_result is not None: return tuple( - np.sqrt(p) * (u if isinstance(u, np.ndarray) else unitary(u)) for p, u in mixture_result + np.sqrt(p) * unitary(u) for p, u in mixture_result ) unitary_getter = getattr(val, '_unitary_', None) diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 5807a96ee29..3139fd4c4a3 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -94,7 +94,7 @@ def mixture( mixture_getter = getattr(val, '_mixture_', None) result = NotImplemented if mixture_getter is None else mixture_getter() if result is not NotImplemented and result is not None: - return tuple((p, u if isinstance(u, np.ndarray) else unitary(u)) for p, u in result) + return tuple((p, unitary(u)) for p, u in result) unitary_getter = getattr(val, '_unitary_', None) result = NotImplemented if unitary_getter is None else unitary_getter() diff --git a/cirq-core/cirq/protocols/unitary_protocol.py b/cirq-core/cirq/protocols/unitary_protocol.py index 87fb3576728..1af3e80da80 100644 --- a/cirq-core/cirq/protocols/unitary_protocol.py +++ b/cirq-core/cirq/protocols/unitary_protocol.py @@ -84,6 +84,7 @@ def unitary( The matrix is determined by any one of the following techniques: + - If the value is a numpy array, it is returned directly. - The value has a `_unitary_` method that returns something besides None or NotImplemented. The matrix is whatever the method returned. - The value has a `_decompose_` method that returns a list of operations, @@ -111,6 +112,9 @@ def unitary( TypeError: `val` doesn't have a unitary effect and no default value was specified. """ + if isinstance(val, np.ndarray): + return val + strats = [ _strat_unitary_from_unitary, _strat_unitary_from_apply_unitary, diff --git a/cirq-core/cirq/protocols/unitary_protocol_test.py b/cirq-core/cirq/protocols/unitary_protocol_test.py index b5ac97a83c7..0ba235e0c6b 100644 --- a/cirq-core/cirq/protocols/unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/unitary_protocol_test.py @@ -161,6 +161,10 @@ def test_unitary(): _ = cirq.unitary(ReturnsNotImplemented()) assert cirq.unitary(ReturnsMatrix()) is m1 + # Test that numpy arrays are handled directly + test_matrix = np.array([[1, 0], [0, 1]]) + assert cirq.unitary(test_matrix, NotImplemented) is test_matrix + assert cirq.unitary(NoMethod(), None) is None assert cirq.unitary(ReturnsNotImplemented(), None) is None assert cirq.unitary(ReturnsMatrix(), None) is m1 From 3ec6d676b2c29461b53e69f83549bf3cffa0bcb9 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 06:05:03 -0700 Subject: [PATCH 02/22] add changes to apply_unitary and unitary to be able to take in numpy array input --- .../cirq/protocols/apply_mixture_protocol.py | 29 ------------------- .../cirq/protocols/apply_unitary_protocol.py | 23 +++++++++------ .../protocols/apply_unitary_protocol_test.py | 3 +- cirq-core/cirq/protocols/kraus_protocol.py | 4 +-- cirq-core/cirq/protocols/mixture_protocol.py | 2 +- cirq-core/cirq/protocols/unitary_protocol.py | 4 +++ .../cirq/protocols/unitary_protocol_test.py | 4 +++ 7 files changed, 26 insertions(+), 43 deletions(-) diff --git a/cirq-core/cirq/protocols/apply_mixture_protocol.py b/cirq-core/cirq/protocols/apply_mixture_protocol.py index 2feb3cc659c..6a064b50881 100644 --- a/cirq-core/cirq/protocols/apply_mixture_protocol.py +++ b/cirq-core/cirq/protocols/apply_mixture_protocol.py @@ -22,7 +22,6 @@ import numpy as np from typing_extensions import Protocol -from cirq import linalg from cirq._doc import doc_private from cirq.protocols import qid_shape_protocol from cirq.protocols.apply_unitary_protocol import apply_unitary, ApplyUnitaryArgs @@ -332,32 +331,6 @@ def _apply_unitary_strat( return right_result -def _apply_unitary_from_matrix_strat( - val: np.ndarray, args: ApplyMixtureArgs, is_density_matrix: bool -) -> np.ndarray | None: - """Used to enact mixture tuples that are given as (probability, np.ndarray) - - If `val` does not support `apply_unitary` returns None. - """ - qid_shape = tuple(args.target_tensor.shape[i] for i in args.left_axes) - matrix_tensor = np.reshape(val.astype(args.target_tensor.dtype), qid_shape * 2) - linalg.targeted_left_multiply( - matrix_tensor, args.target_tensor, args.left_axes, out=args.auxiliary_buffer0 - ) - - if not is_density_matrix: - return args.auxiliary_buffer0 - # No need to transpose as we are acting on the tensor - # representation of matrix, so transpose is done for us. - linalg.targeted_left_multiply( - np.conjugate(matrix_tensor), - args.auxiliary_buffer0, - cast(tuple[int], args.right_axes), - out=args.target_tensor, - ) - return args.target_tensor - - def _apply_mixture_from_mixture_strat( val: Any, args: ApplyMixtureArgs, is_density_matrix: bool ) -> np.ndarray | None: @@ -373,8 +346,6 @@ def _apply_mixture_from_mixture_strat( for prob, op in prob_mix: np.copyto(dst=args.target_tensor, src=args.auxiliary_buffer1) right_result = _apply_unitary_strat(op, args, is_density_matrix) - if right_result is None: - right_result = _apply_unitary_from_matrix_strat(op, args, is_density_matrix) args.out_buffer += prob * right_result diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol.py b/cirq-core/cirq/protocols/apply_unitary_protocol.py index 9aeb37ad509..5791ff6f088 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol.py @@ -469,15 +469,20 @@ def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: App def _strat_apply_unitary_from_unitary( unitary_value: Any, args: ApplyUnitaryArgs ) -> np.ndarray | None: - # Check for magic method. - method = getattr(unitary_value, '_unitary_', None) - if method is None: - return NotImplemented - - # Attempt to get the unitary matrix. - matrix = method() - if matrix is NotImplemented or matrix is None: - return matrix + if isinstance(unitary_value, np.ndarray): + matrix = unitary_value + if not linalg.is_unitary(matrix): + return None + else: + # Check for magic method. + method = getattr(unitary_value, '_unitary_', None) + if method is None: + return NotImplemented + + # Attempt to get the unitary matrix. + matrix = method() + if matrix is NotImplemented or matrix is None: + return matrix return _apply_unitary_from_matrix(matrix, unitary_value, args) diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py index 8f2acc08278..94b3ee58733 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py @@ -56,12 +56,13 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray: args.target_tensor[one] *= -1 return args.target_tensor - fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented()] + fails = [NoUnitaryEffect(), HasApplyReturnsNotImplemented(), m * 2] passes = [ HasUnitary(), HasApplyReturnsNotImplementedButHasUnitary(), HasApplyOutputInBuffer(), HasApplyMutateInline(), + m, ] def make_input(): diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 09af067d9f6..32f0da9d1cf 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -148,9 +148,7 @@ def kraus( mixture_getter = getattr(val, '_mixture_', None) mixture_result = NotImplemented if mixture_getter is None else mixture_getter() if mixture_result is not NotImplemented and mixture_result is not None: - return tuple( - np.sqrt(p) * (u if isinstance(u, np.ndarray) else unitary(u)) for p, u in mixture_result - ) + return tuple(np.sqrt(p) * unitary(u) for p, u in mixture_result) unitary_getter = getattr(val, '_unitary_', None) unitary_result = NotImplemented if unitary_getter is None else unitary_getter() diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index 5807a96ee29..3139fd4c4a3 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -94,7 +94,7 @@ def mixture( mixture_getter = getattr(val, '_mixture_', None) result = NotImplemented if mixture_getter is None else mixture_getter() if result is not NotImplemented and result is not None: - return tuple((p, u if isinstance(u, np.ndarray) else unitary(u)) for p, u in result) + return tuple((p, unitary(u)) for p, u in result) unitary_getter = getattr(val, '_unitary_', None) result = NotImplemented if unitary_getter is None else unitary_getter() diff --git a/cirq-core/cirq/protocols/unitary_protocol.py b/cirq-core/cirq/protocols/unitary_protocol.py index 87fb3576728..1af3e80da80 100644 --- a/cirq-core/cirq/protocols/unitary_protocol.py +++ b/cirq-core/cirq/protocols/unitary_protocol.py @@ -84,6 +84,7 @@ def unitary( The matrix is determined by any one of the following techniques: + - If the value is a numpy array, it is returned directly. - The value has a `_unitary_` method that returns something besides None or NotImplemented. The matrix is whatever the method returned. - The value has a `_decompose_` method that returns a list of operations, @@ -111,6 +112,9 @@ def unitary( TypeError: `val` doesn't have a unitary effect and no default value was specified. """ + if isinstance(val, np.ndarray): + return val + strats = [ _strat_unitary_from_unitary, _strat_unitary_from_apply_unitary, diff --git a/cirq-core/cirq/protocols/unitary_protocol_test.py b/cirq-core/cirq/protocols/unitary_protocol_test.py index b5ac97a83c7..0ba235e0c6b 100644 --- a/cirq-core/cirq/protocols/unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/unitary_protocol_test.py @@ -161,6 +161,10 @@ def test_unitary(): _ = cirq.unitary(ReturnsNotImplemented()) assert cirq.unitary(ReturnsMatrix()) is m1 + # Test that numpy arrays are handled directly + test_matrix = np.array([[1, 0], [0, 1]]) + assert cirq.unitary(test_matrix, NotImplemented) is test_matrix + assert cirq.unitary(NoMethod(), None) is None assert cirq.unitary(ReturnsNotImplemented(), None) is None assert cirq.unitary(ReturnsMatrix(), None) is m1 From 381f1057cb7d8302dcc8905842246f9734f7a2bc Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 07:13:30 -0700 Subject: [PATCH 03/22] handle nd array in has_unitary and address pr comment --- cirq-core/cirq/protocols/has_unitary_protocol.py | 4 +++- cirq-core/cirq/protocols/has_unitary_protocol_test.py | 3 +++ cirq-core/cirq/protocols/unitary_protocol.py | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/protocols/has_unitary_protocol.py b/cirq-core/cirq/protocols/has_unitary_protocol.py index 876ea7da4ee..a203a809f2c 100644 --- a/cirq-core/cirq/protocols/has_unitary_protocol.py +++ b/cirq-core/cirq/protocols/has_unitary_protocol.py @@ -19,7 +19,7 @@ import numpy as np from typing_extensions import Protocol -from cirq import qis +from cirq import linalg, qis from cirq._doc import doc_private from cirq.protocols import qid_shape_protocol from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs @@ -112,6 +112,8 @@ def has_unitary(val: Any, *, allow_decompose: bool = True) -> bool: def _strat_has_unitary_from_has_unitary(val: Any) -> bool | None: """Attempts to infer a value's unitary-ness via its _has_unitary_ method.""" + if isinstance(val, np.ndarray): + return linalg.is_unitary(val) if hasattr(val, '_has_unitary_'): result = val._has_unitary_() if result is NotImplemented: diff --git a/cirq-core/cirq/protocols/has_unitary_protocol_test.py b/cirq-core/cirq/protocols/has_unitary_protocol_test.py index 60595bd2e98..b29023edbe3 100644 --- a/cirq-core/cirq/protocols/has_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/has_unitary_protocol_test.py @@ -61,10 +61,13 @@ class Yes: def _unitary_(self): return np.array([[1]]) + m = np.diag([1, -1]) assert not cirq.has_unitary(No1()) assert not cirq.has_unitary(No2()) + assert not cirq.has_unitary(m * 2) assert cirq.has_unitary(Yes()) assert cirq.has_unitary(Yes(), allow_decompose=False) + assert cirq.has_unitary(m) def test_via_apply_unitary() -> None: diff --git a/cirq-core/cirq/protocols/unitary_protocol.py b/cirq-core/cirq/protocols/unitary_protocol.py index 1af3e80da80..5169af943a5 100644 --- a/cirq-core/cirq/protocols/unitary_protocol.py +++ b/cirq-core/cirq/protocols/unitary_protocol.py @@ -20,6 +20,7 @@ import numpy as np from typing_extensions import Protocol +from cirq import linalg from cirq._doc import doc_private from cirq.protocols import qid_shape_protocol from cirq.protocols.apply_unitary_protocol import apply_unitaries, ApplyUnitaryArgs @@ -113,6 +114,8 @@ def unitary( specified. """ if isinstance(val, np.ndarray): + if not linalg.is_unitary(val): + raise ValueError("The provided numpy array is not unitary.") return val strats = [ From 5ed34efe13a8ba493f939a7969fe316710bc475b Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 07:50:03 -0700 Subject: [PATCH 04/22] fix related tests --- cirq-core/cirq/protocols/mixture_protocol_test.py | 4 ++-- .../analytical_decompositions/two_qubit_to_sycamore_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index 9d37ca697b0..0d2d3872d42 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -19,8 +19,8 @@ import cirq -a = np.array([1]) -b = np.array([1j]) +a = np.eye(2) +b = np.eye(2) class NoMethod: diff --git a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py index 9ad990f3ea0..956635a6554 100644 --- a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py +++ b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py @@ -86,7 +86,7 @@ def test_known_two_qubit_op_decomposition(op, theta_range): cirq.FSimGate(0.25, 0.85).on(*_QUBITS), cirq.XX(*_QUBITS), cirq.YY(*_QUBITS), - *[cirq.testing.random_unitary(4, random_state=1234) for _ in range(10)], + *[cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(*_QUBITS) for _ in range(10)], ], ) def test_unknown_two_qubit_op_decomposition(op): From 2044d0a447860c36bae5100096311406d3492ad7 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 08:06:25 -0700 Subject: [PATCH 05/22] lint formatting etc --- cirq-core/cirq/protocols/pauli_expansion_protocol_test.py | 2 +- cirq-core/cirq/protocols/unitary_protocol.py | 1 + .../analytical_decompositions/two_qubit_to_sycamore_test.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/pauli_expansion_protocol_test.py b/cirq-core/cirq/protocols/pauli_expansion_protocol_test.py index ae7916ce226..be515e9b2a4 100644 --- a/cirq-core/cirq/protocols/pauli_expansion_protocol_test.py +++ b/cirq-core/cirq/protocols/pauli_expansion_protocol_test.py @@ -54,7 +54,7 @@ def _unitary_(self) -> np.ndarray: @pytest.mark.parametrize( - 'val', (NoMethod(), ReturnsNotImplemented(), HasQuditUnitary(), 123, np.eye(2), object(), cirq) + 'val', (NoMethod(), ReturnsNotImplemented(), HasQuditUnitary(), 123, object(), cirq) ) def test_raises_no_pauli_expansion(val) -> None: assert cirq.pauli_expansion(val, default=None) is None diff --git a/cirq-core/cirq/protocols/unitary_protocol.py b/cirq-core/cirq/protocols/unitary_protocol.py index 5169af943a5..0e77947267f 100644 --- a/cirq-core/cirq/protocols/unitary_protocol.py +++ b/cirq-core/cirq/protocols/unitary_protocol.py @@ -112,6 +112,7 @@ def unitary( Raises: TypeError: `val` doesn't have a unitary effect and no default value was specified. + ValueError: `val` is a numpy array that is not unitary. """ if isinstance(val, np.ndarray): if not linalg.is_unitary(val): diff --git a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py index 956635a6554..2e1b1291835 100644 --- a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py +++ b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py @@ -86,7 +86,10 @@ def test_known_two_qubit_op_decomposition(op, theta_range): cirq.FSimGate(0.25, 0.85).on(*_QUBITS), cirq.XX(*_QUBITS), cirq.YY(*_QUBITS), - *[cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(*_QUBITS) for _ in range(10)], + *[ + cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(*_QUBITS) + for _ in range(10) + ], ], ) def test_unknown_two_qubit_op_decomposition(op): From 055245c671966918cb917b2ad969cc3a6925bbd6 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 09:19:13 -0700 Subject: [PATCH 06/22] add coverage --- cirq-core/cirq/protocols/unitary_protocol_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cirq-core/cirq/protocols/unitary_protocol_test.py b/cirq-core/cirq/protocols/unitary_protocol_test.py index 0ba235e0c6b..3ea3ed4465a 100644 --- a/cirq-core/cirq/protocols/unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/unitary_protocol_test.py @@ -165,6 +165,11 @@ def test_unitary(): test_matrix = np.array([[1, 0], [0, 1]]) assert cirq.unitary(test_matrix, NotImplemented) is test_matrix + # Test that non-unitary numpy arrays raise ValueError + non_unitary_matrix = np.array([[1, 1], [0, 1]]) + with pytest.raises(ValueError, match="The provided numpy array is not unitary"): + _ = cirq.unitary(non_unitary_matrix) + assert cirq.unitary(NoMethod(), None) is None assert cirq.unitary(ReturnsNotImplemented(), None) is None assert cirq.unitary(ReturnsMatrix(), None) is m1 From 175c719ee4fee1bee55757d8984c9e4e8c4fe169 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 12 Jun 2025 16:36:48 -0700 Subject: [PATCH 07/22] fix test case --- cirq-core/cirq/protocols/mixture_protocol_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index 0d2d3872d42..ac6abbf7dfe 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -19,8 +19,8 @@ import cirq -a = np.eye(2) -b = np.eye(2) +a = np.array([[1]]) +b = np.array([[1j]]) class NoMethod: From 1ddd07cfe1c8250973e60702d80156bcbbda5478 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 18 Jun 2025 01:03:33 -0700 Subject: [PATCH 08/22] update sycamore test based on pr comment --- .../two_qubit_to_sycamore_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py index 2e1b1291835..485110979b3 100644 --- a/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py +++ b/cirq-google/cirq_google/transformers/analytical_decompositions/two_qubit_to_sycamore_test.py @@ -86,15 +86,12 @@ def test_known_two_qubit_op_decomposition(op, theta_range): cirq.FSimGate(0.25, 0.85).on(*_QUBITS), cirq.XX(*_QUBITS), cirq.YY(*_QUBITS), - *[ - cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(*_QUBITS) - for _ in range(10) - ], + cirq.MatrixGate(cirq.testing.random_unitary(4)).on(*_QUBITS), ], ) def test_unknown_two_qubit_op_decomposition(op): assert cg.known_2q_op_to_sycamore_operations(op) is None - if cirq.has_unitary(op) and cirq.num_qubits(op) == 2: + if not cirq.is_parameterized(op) and cirq.num_qubits(op) == 2: matrix_2q_circuit = cirq.Circuit( cg.two_qubit_matrix_to_sycamore_operations(_QUBITS[0], _QUBITS[1], cirq.unitary(op)) ) From 021f09b8a770e5238a9a5579f3bfa618d588d35b Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Tue, 24 Jun 2025 22:53:50 +0900 Subject: [PATCH 09/22] update kraus protocol to add apply_channel fallback --- cirq-core/cirq/protocols/kraus_protocol.py | 35 +++++++++++++++++++ .../cirq/protocols/kraus_protocol_test.py | 27 ++++++++++++++ cirq-core/cirq/testing/circuit_compare.py | 3 ++ 3 files changed, 65 insertions(+) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 32f0da9d1cf..ffd33c503e4 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -23,6 +23,7 @@ import numpy as np from typing_extensions import Protocol +from cirq import protocols, qis from cirq._doc import doc_private from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.protocols.mixture_protocol import has_mixture @@ -94,6 +95,35 @@ def _has_kraus_(self) -> bool: """ +def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None: + """Attempts to compute a value's Kraus operators via its _apply_channel_ method. + This is very expensive (O(16^N)), so only do this as a last resort.""" + method = getattr(val, '_apply_channel_', None) + if method is None: + return NotImplemented + + qid_shape = protocols.qid_shape(val) + + eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128) + superop = protocols.apply_channel( + val=val, + args=protocols.ApplyChannelArgs( + target_tensor=eye, + out_buffer=np.ones_like(eye) * float('nan'), + auxiliary_buffer0=np.ones_like(eye) * float('nan'), + auxiliary_buffer1=np.ones_like(eye) * float('nan'), + left_axes=list(range(len(qid_shape))), + right_axes=list(range(len(qid_shape), len(qid_shape) * 2)), + ), + default=None, + ) + if superop is None or superop is NotImplemented: + return NotImplemented + n = np.prod(qid_shape) ** 2 + kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n))) + return tuple(kraus_ops) + + def kraus( val: Any, default: Any = RaiseTypeErrorIfNotProvided ) -> tuple[np.ndarray, ...] | TDefault: @@ -162,6 +192,11 @@ def kraus( if default is not RaiseTypeErrorIfNotProvided: return default + # Last-resort fallback: try to derive Kraus from _apply_channel_ + result = _strat_kraus_from_apply_channel(val) + if result is not NotImplemented: + return result + if kraus_getter is None and unitary_getter is None and mixture_getter is None: raise TypeError( f"object of type '{type(val)}' has no _kraus_ or _mixture_ or _unitary_ method." diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 7ce4dd900fd..57ebbc7d0e0 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -171,3 +171,30 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None: op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) assert cirq.has_kraus(op) assert not cirq.has_kraus(op, allow_decompose=False) + + +def test_kraus_fallback_to_apply_channel() -> None: + """Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_.""" + p = 0.5 + K0 = np.sqrt(1 - p) * np.eye(2) + K1 = np.sqrt(p) * np.array([[0, 1], [1, 0]]) + expected_kraus = (K0, K1) + + class BitFlipChannel: + def _num_qubits_(self): + return 1 + + def _apply_channel_(self, args: cirq.ApplyChannelArgs): + X = np.array([[0, 1], [1, 0]], dtype=np.complex128) + rho = args.target_tensor + out = (1 - p) * rho + p * X @ rho @ X + args.out_buffer[...] = out + return args.out_buffer + + chan = BitFlipChannel() + kraus_ops = cirq.kraus(chan) + + # Compare the superoperator matrices for equivalence + expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) + actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) + np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index 3f927a5579f..f5548aa1849 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -336,6 +336,9 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None atol: Absolute error tolerance. """ __tracebackhide__ = True + method = getattr(val, '_apply_channel_', None) + if method is None: + return False kraus = protocols.kraus(val, default=None) expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None From fdcd59526d0d4edf34402ce126eda473aa64cf4e Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Tue, 24 Jun 2025 23:04:22 +0900 Subject: [PATCH 10/22] . --- cirq-core/cirq/testing/circuit_compare.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index f5548aa1849..ec45f7c0ed0 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -336,9 +336,7 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None atol: Absolute error tolerance. """ __tracebackhide__ = True - method = getattr(val, '_apply_channel_', None) - if method is None: - return False + assert hasattr(val, '_apply_channel_', None) kraus = protocols.kraus(val, default=None) expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None From 3bf33eb436d27d88b757c43d0994202452ba3aa2 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Tue, 24 Jun 2025 23:11:05 +0900 Subject: [PATCH 11/22] . --- cirq-core/cirq/testing/circuit_compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index ec45f7c0ed0..ea0870d0bfb 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -336,7 +336,7 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None atol: Absolute error tolerance. """ __tracebackhide__ = True - assert hasattr(val, '_apply_channel_', None) + assert hasattr(val, '_apply_channel_') kraus = protocols.kraus(val, default=None) expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None From 69c201446c5cf1c7ed831ab7930fe5c85ba2cc49 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Tue, 24 Jun 2025 23:27:54 +0900 Subject: [PATCH 12/22] . --- cirq-core/cirq/protocols/kraus_protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index ffd33c503e4..e3b90c50812 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -100,7 +100,7 @@ def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None: This is very expensive (O(16^N)), so only do this as a last resort.""" method = getattr(val, '_apply_channel_', None) if method is None: - return NotImplemented + return None qid_shape = protocols.qid_shape(val) @@ -118,7 +118,7 @@ def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None: default=None, ) if superop is None or superop is NotImplemented: - return NotImplemented + return None n = np.prod(qid_shape) ** 2 kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n))) return tuple(kraus_ops) @@ -194,7 +194,7 @@ def kraus( # Last-resort fallback: try to derive Kraus from _apply_channel_ result = _strat_kraus_from_apply_channel(val) - if result is not NotImplemented: + if result is not None: return result if kraus_getter is None and unitary_getter is None and mixture_getter is None: From 6e15904cde13f7cceea94d958deccf14b5d6993b Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Tue, 24 Jun 2025 23:57:50 +0900 Subject: [PATCH 13/22] . --- cirq-core/cirq/protocols/kraus_protocol_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 57ebbc7d0e0..98b174470ee 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -198,3 +198,10 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) + + +def test_strat_kraus_from_apply_channel_returns_none(): + from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel + class NoApplyChannel: + pass + assert _strat_kraus_from_apply_channel(NoApplyChannel()) is None From 3a1f45e224041fe2d1069d717809326b7461666e Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 00:02:23 +0900 Subject: [PATCH 14/22] . --- cirq-core/cirq/protocols/kraus_protocol_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 98b174470ee..b5b518f1e34 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -202,6 +202,8 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): def test_strat_kraus_from_apply_channel_returns_none(): from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel + class NoApplyChannel: pass + assert _strat_kraus_from_apply_channel(NoApplyChannel()) is None From c09d37cd12335c74aa8ad5839cc607976990c5fb Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 00:24:35 +0900 Subject: [PATCH 15/22] . --- cirq-core/cirq/protocols/kraus_protocol_test.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index b5b518f1e34..b298ef5edc9 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -203,7 +203,11 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): def test_strat_kraus_from_apply_channel_returns_none(): from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel - class NoApplyChannel: - pass + class ApplyChannelReturnsNone: + def _apply_channel_(self, *args, **kwargs): + return None + + def _num_qubits_(self): + return 1 # Needed for qid_shape - assert _strat_kraus_from_apply_channel(NoApplyChannel()) is None + assert _strat_kraus_from_apply_channel(ApplyChannelReturnsNone()) is None From db5f020ab410f97453fdc46724096adcfb6b2709 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 00:45:18 +0900 Subject: [PATCH 16/22] . --- .../pauli_string_measurement_with_readout_mitigation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py index 0928db748f3..e47d3b5ff6c 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py @@ -325,7 +325,7 @@ def _process_pauli_measurement_results( pauli_readout_qubits = _extract_readout_qubits(pauli_strs) calibration_result = ( - calibration_results[tuple(pauli_readout_qubits)] + calibration_results.get(tuple(pauli_readout_qubits), None) if disable_readout_mitigation is False else None ) From d616813534967254a51602979e86b464d8156ab2 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 23:30:09 +0900 Subject: [PATCH 17/22] reuse applying kraus operator --- .../cirq/protocols/kraus_protocol_test.py | 66 +++++++++++++------ cirq-core/cirq/qis/channels_test.py | 10 +-- cirq-core/cirq/testing/circuit_compare.py | 19 ++++++ 3 files changed, 66 insertions(+), 29 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index b298ef5edc9..a8d198186d8 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -22,6 +22,7 @@ import pytest import cirq +from cirq.testing.circuit_compare import apply_kraus_operators LOCAL_DEFAULT: list[np.ndarray] = [np.array([])] @@ -173,14 +174,25 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None: assert not cirq.has_kraus(op, allow_decompose=False) -def test_kraus_fallback_to_apply_channel() -> None: - """Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_.""" +def test_strat_kraus_from_apply_channel_returns_none(): + from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel + + class ApplyChannelReturnsNone: + def _apply_channel_(self, *args, **kwargs): + return None + + def _num_qubits_(self): + return 1 # Needed for qid_shape + + assert _strat_kraus_from_apply_channel(ApplyChannelReturnsNone()) is None + + +def test_kraus_fallback_to_apply_channel_bitflipchannel_real() -> None: + """Test fallback using the real cirq.BitFlipChannel and compare to a custom channel.""" p = 0.5 - K0 = np.sqrt(1 - p) * np.eye(2) - K1 = np.sqrt(p) * np.array([[0, 1], [1, 0]]) - expected_kraus = (K0, K1) + expected_kraus = cirq.kraus(cirq.BitFlipChannel(p)) - class BitFlipChannel: + class CustomBitFlipChannel: def _num_qubits_(self): return 1 @@ -191,23 +203,37 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): args.out_buffer[...] = out return args.out_buffer - chan = BitFlipChannel() - kraus_ops = cirq.kraus(chan) + kraus_ops = cirq.kraus(CustomBitFlipChannel()) + # Compare the action on a test density matrix + rho = np.array([[0.7, 0.2], [0.2, 0.3]], dtype=np.complex128) + expected_rho = apply_kraus_operators(expected_kraus, rho) + actual_rho = apply_kraus_operators(kraus_ops, rho) + np.testing.assert_allclose(actual_rho, expected_rho, atol=1e-8) - # Compare the superoperator matrices for equivalence - expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) - actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) - np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) +def test_reset_channel_kraus_apply_channel_consistency(): + """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, even if one is missing.""" + Reset = cirq.ResetChannel + # Original gate + gate = Reset() + cirq.testing.assert_has_consistent_apply_channel(gate) + cirq.testing.assert_consistent_channel(gate) -def test_strat_kraus_from_apply_channel_returns_none(): - from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel + # Remove _kraus_ method + class NoKrausReset(Reset): + def _kraus_(self): + return NotImplemented - class ApplyChannelReturnsNone: - def _apply_channel_(self, *args, **kwargs): - return None + gate_no_kraus = NoKrausReset() + # Should still match the original superoperator + np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8) - def _num_qubits_(self): - return 1 # Needed for qid_shape + # Remove _apply_channel_ method + class NoApplyChannelReset(Reset): + def _apply_channel_(self, args): + return NotImplemented - assert _strat_kraus_from_apply_channel(ApplyChannelReturnsNone()) is None + gate_no_apply = NoApplyChannelReset() + cirq.testing.assert_consistent_channel(gate_no_apply) + # Should still match the original superoperator + np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_apply), atol=1e-8) diff --git a/cirq-core/cirq/qis/channels_test.py b/cirq-core/cirq/qis/channels_test.py index d24108d6437..41561d95242 100644 --- a/cirq-core/cirq/qis/channels_test.py +++ b/cirq-core/cirq/qis/channels_test.py @@ -22,21 +22,13 @@ import pytest import cirq +from cirq.testing.circuit_compare import apply_kraus_operators def apply_channel(channel: cirq.SupportsKraus, rho: np.ndarray) -> np.ndarray: return apply_kraus_operators(cirq.kraus(channel), rho) -def apply_kraus_operators(kraus_operators: Sequence[np.ndarray], rho: np.ndarray) -> np.ndarray: - d_out, d_in = kraus_operators[0].shape - assert rho.shape == (d_in, d_in) - out = np.zeros((d_out, d_out), dtype=np.complex128) - for k in kraus_operators: - out += k @ rho @ k.conj().T - return out - - def generate_standard_operator_basis(d_out: int, d_in: int) -> Iterable[np.ndarray]: for i in range(d_out): for j in range(d_in): diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index ea0870d0bfb..5dbc1f98aa8 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -491,3 +491,22 @@ def assert_has_consistent_qid_shape(val: Any) -> None: assert num_qubits == len( val.qubits ), f'Length of num_qubits and val.qubits disagrees: {num_qubits}, {len(val.qubits)}' + + +def apply_kraus_operators(kraus_operators: Sequence[np.ndarray], rho: np.ndarray) -> np.ndarray: + """ + Applies a quantum channel (in Kraus operator form) to a density matrix. + + Args: + kraus_operators: Sequence of Kraus operators specifying the channel. + rho: The input density matrix. + + Returns: + The output density matrix after the channel is applied. + """ + d_out, d_in = kraus_operators[0].shape + assert rho.shape == (d_in, d_in) + out = np.zeros((d_out, d_out), dtype=np.complex128) + for k in kraus_operators: + out += k @ rho @ k.conj().T + return out From c132b3fb8261bbb0a4e5f81680ab631f98660db7 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 23:37:24 +0900 Subject: [PATCH 18/22] . --- cirq-core/cirq/protocols/kraus_protocol_test.py | 3 ++- cirq-core/cirq/qis/channels_test.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index a8d198186d8..8c9e88b65ab 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -212,7 +212,8 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): def test_reset_channel_kraus_apply_channel_consistency(): - """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, even if one is missing.""" + """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, + even if one is missing.""" Reset = cirq.ResetChannel # Original gate gate = Reset() diff --git a/cirq-core/cirq/qis/channels_test.py b/cirq-core/cirq/qis/channels_test.py index 41561d95242..af3445ec875 100644 --- a/cirq-core/cirq/qis/channels_test.py +++ b/cirq-core/cirq/qis/channels_test.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Iterable, Sequence +from typing import Iterable import numpy as np import pytest From 97bc7933b0ca95a917d9d6c33deb41a82d9ffb23 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Wed, 25 Jun 2025 23:37:24 +0900 Subject: [PATCH 19/22] . --- cirq-core/cirq/protocols/kraus_protocol_test.py | 3 ++- cirq-core/cirq/qis/channels_test.py | 2 +- cirq-core/cirq/testing/circuit_compare.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index a8d198186d8..8c9e88b65ab 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -212,7 +212,8 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): def test_reset_channel_kraus_apply_channel_consistency(): - """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, even if one is missing.""" + """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, + even if one is missing.""" Reset = cirq.ResetChannel # Original gate gate = Reset() diff --git a/cirq-core/cirq/qis/channels_test.py b/cirq-core/cirq/qis/channels_test.py index 41561d95242..af3445ec875 100644 --- a/cirq-core/cirq/qis/channels_test.py +++ b/cirq-core/cirq/qis/channels_test.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Iterable, Sequence +from typing import Iterable import numpy as np import pytest diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index 5dbc1f98aa8..4293a9acea9 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -494,8 +494,7 @@ def assert_has_consistent_qid_shape(val: Any) -> None: def apply_kraus_operators(kraus_operators: Sequence[np.ndarray], rho: np.ndarray) -> np.ndarray: - """ - Applies a quantum channel (in Kraus operator form) to a density matrix. + """Applies a quantum channel (in Kraus operator form) to a density matrix. Args: kraus_operators: Sequence of Kraus operators specifying the channel. From c30ae4ec9568857c8d2474039e827da5a6710653 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 26 Jun 2025 00:18:46 +0900 Subject: [PATCH 20/22] . --- .../cirq/protocols/kraus_protocol_test.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 8c9e88b65ab..4ee35bdd433 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -175,20 +175,23 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None: def test_strat_kraus_from_apply_channel_returns_none(): - from cirq.protocols.kraus_protocol import _strat_kraus_from_apply_channel - - class ApplyChannelReturnsNone: - def _apply_channel_(self, *args, **kwargs): - return None + # Remove _kraus_ and _apply_channel_ methods + class NoApplyChannelReset(cirq.ResetChannel): + def _kraus_(self): + return NotImplemented - def _num_qubits_(self): - return 1 # Needed for qid_shape + def _apply_channel_(self, args): + return NotImplemented - assert _strat_kraus_from_apply_channel(ApplyChannelReturnsNone()) is None + gate_no_apply = NoApplyChannelReset() + with pytest.raises( + TypeError, + match="does have a _kraus_, _mixture_ or _unitary_ method, but it returned NotImplemented", + ): + cirq.kraus(gate_no_apply) def test_kraus_fallback_to_apply_channel_bitflipchannel_real() -> None: - """Test fallback using the real cirq.BitFlipChannel and compare to a custom channel.""" p = 0.5 expected_kraus = cirq.kraus(cirq.BitFlipChannel(p)) @@ -212,8 +215,6 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): def test_reset_channel_kraus_apply_channel_consistency(): - """Test that ResetChannel's _kraus_ and _apply_channel_ produce the same channel, - even if one is missing.""" Reset = cirq.ResetChannel # Original gate gate = Reset() @@ -228,13 +229,3 @@ def _kraus_(self): gate_no_kraus = NoKrausReset() # Should still match the original superoperator np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8) - - # Remove _apply_channel_ method - class NoApplyChannelReset(Reset): - def _apply_channel_(self, args): - return NotImplemented - - gate_no_apply = NoApplyChannelReset() - cirq.testing.assert_consistent_channel(gate_no_apply) - # Should still match the original superoperator - np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_apply), atol=1e-8) From c88b5688502542c6fd075a2f0e9df50f831328b1 Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 26 Jun 2025 09:45:52 +0900 Subject: [PATCH 21/22] . --- .../cirq/protocols/kraus_protocol_test.py | 48 ++++++++++++------- cirq-core/cirq/qis/channels_test.py | 12 ++++- cirq-core/cirq/testing/circuit_compare.py | 18 ------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 4ee35bdd433..8f8b4ad2ade 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -22,7 +22,7 @@ import pytest import cirq -from cirq.testing.circuit_compare import apply_kraus_operators +from cirq.protocols.apply_channel_protocol import _apply_kraus LOCAL_DEFAULT: list[np.ndarray] = [np.array([])] @@ -191,27 +191,41 @@ def _apply_channel_(self, args): cirq.kraus(gate_no_apply) -def test_kraus_fallback_to_apply_channel_bitflipchannel_real() -> None: - p = 0.5 - expected_kraus = cirq.kraus(cirq.BitFlipChannel(p)) +@pytest.mark.parametrize( + 'channel_cls,params', + [ + (cirq.BitFlipChannel, (0.5,)), + (cirq.PhaseFlipChannel, (0.3,)), + (cirq.DepolarizingChannel, (0.2,)), + (cirq.AmplitudeDampingChannel, (0.4,)), + (cirq.PhaseDampingChannel, (0.25,)), + ], +) +def test_kraus_fallback_to_apply_channel(channel_cls, params) -> None: + """Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_.""" + # Create the expected channel and get its Kraus operators + expected_channel = channel_cls(*params) + expected_kraus = cirq.kraus(expected_channel) + + class TestChannel: + def __init__(self, channel_cls, params): + self.channel_cls = channel_cls + self.params = params + self.expected_kraus = cirq.kraus(channel_cls(*params)) - class CustomBitFlipChannel: def _num_qubits_(self): return 1 def _apply_channel_(self, args: cirq.ApplyChannelArgs): - X = np.array([[0, 1], [1, 0]], dtype=np.complex128) - rho = args.target_tensor - out = (1 - p) * rho + p * X @ rho @ X - args.out_buffer[...] = out - return args.out_buffer - - kraus_ops = cirq.kraus(CustomBitFlipChannel()) - # Compare the action on a test density matrix - rho = np.array([[0.7, 0.2], [0.2, 0.3]], dtype=np.complex128) - expected_rho = apply_kraus_operators(expected_kraus, rho) - actual_rho = apply_kraus_operators(kraus_ops, rho) - np.testing.assert_allclose(actual_rho, expected_rho, atol=1e-8) + return _apply_kraus(self.expected_kraus, args) + + chan = TestChannel(channel_cls, params) + kraus_ops = cirq.kraus(chan) + + # Compare the superoperator matrices for equivalence + expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) + actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) + np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) def test_reset_channel_kraus_apply_channel_consistency(): diff --git a/cirq-core/cirq/qis/channels_test.py b/cirq-core/cirq/qis/channels_test.py index af3445ec875..d24108d6437 100644 --- a/cirq-core/cirq/qis/channels_test.py +++ b/cirq-core/cirq/qis/channels_test.py @@ -16,19 +16,27 @@ from __future__ import annotations -from typing import Iterable +from typing import Iterable, Sequence import numpy as np import pytest import cirq -from cirq.testing.circuit_compare import apply_kraus_operators def apply_channel(channel: cirq.SupportsKraus, rho: np.ndarray) -> np.ndarray: return apply_kraus_operators(cirq.kraus(channel), rho) +def apply_kraus_operators(kraus_operators: Sequence[np.ndarray], rho: np.ndarray) -> np.ndarray: + d_out, d_in = kraus_operators[0].shape + assert rho.shape == (d_in, d_in) + out = np.zeros((d_out, d_out), dtype=np.complex128) + for k in kraus_operators: + out += k @ rho @ k.conj().T + return out + + def generate_standard_operator_basis(d_out: int, d_in: int) -> Iterable[np.ndarray]: for i in range(d_out): for j in range(d_in): diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index 4293a9acea9..ea0870d0bfb 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -491,21 +491,3 @@ def assert_has_consistent_qid_shape(val: Any) -> None: assert num_qubits == len( val.qubits ), f'Length of num_qubits and val.qubits disagrees: {num_qubits}, {len(val.qubits)}' - - -def apply_kraus_operators(kraus_operators: Sequence[np.ndarray], rho: np.ndarray) -> np.ndarray: - """Applies a quantum channel (in Kraus operator form) to a density matrix. - - Args: - kraus_operators: Sequence of Kraus operators specifying the channel. - rho: The input density matrix. - - Returns: - The output density matrix after the channel is applied. - """ - d_out, d_in = kraus_operators[0].shape - assert rho.shape == (d_in, d_in) - out = np.zeros((d_out, d_out), dtype=np.complex128) - for k in kraus_operators: - out += k @ rho @ k.conj().T - return out From 4a65fd5a83458178e0efa02592ed9a4e047119de Mon Sep 17 00:00:00 2001 From: iamsusiep Date: Thu, 3 Jul 2025 22:03:21 +0900 Subject: [PATCH 22/22] . --- .../pauli_string_measurement_with_readout_mitigation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py index e47d3b5ff6c..0928db748f3 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py @@ -325,7 +325,7 @@ def _process_pauli_measurement_results( pauli_readout_qubits = _extract_readout_qubits(pauli_strs) calibration_result = ( - calibration_results.get(tuple(pauli_readout_qubits), None) + calibration_results[tuple(pauli_readout_qubits)] if disable_readout_mitigation is False else None )