diff --git a/cirq-core/cirq/protocols/kraus_protocol.py b/cirq-core/cirq/protocols/kraus_protocol.py index 32f0da9d1cf..e3b90c50812 100644 --- a/cirq-core/cirq/protocols/kraus_protocol.py +++ b/cirq-core/cirq/protocols/kraus_protocol.py @@ -23,6 +23,7 @@ import numpy as np from typing_extensions import Protocol +from cirq import protocols, qis from cirq._doc import doc_private from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.protocols.mixture_protocol import has_mixture @@ -94,6 +95,35 @@ def _has_kraus_(self) -> bool: """ +def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None: + """Attempts to compute a value's Kraus operators via its _apply_channel_ method. + This is very expensive (O(16^N)), so only do this as a last resort.""" + method = getattr(val, '_apply_channel_', None) + if method is None: + return None + + qid_shape = protocols.qid_shape(val) + + eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128) + superop = protocols.apply_channel( + val=val, + args=protocols.ApplyChannelArgs( + target_tensor=eye, + out_buffer=np.ones_like(eye) * float('nan'), + auxiliary_buffer0=np.ones_like(eye) * float('nan'), + auxiliary_buffer1=np.ones_like(eye) * float('nan'), + left_axes=list(range(len(qid_shape))), + right_axes=list(range(len(qid_shape), len(qid_shape) * 2)), + ), + default=None, + ) + if superop is None or superop is NotImplemented: + return None + n = np.prod(qid_shape) ** 2 + kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n))) + return tuple(kraus_ops) + + def kraus( val: Any, default: Any = RaiseTypeErrorIfNotProvided ) -> tuple[np.ndarray, ...] | TDefault: @@ -162,6 +192,11 @@ def kraus( if default is not RaiseTypeErrorIfNotProvided: return default + # Last-resort fallback: try to derive Kraus from _apply_channel_ + result = _strat_kraus_from_apply_channel(val) + if result is not None: + return result + if kraus_getter is None and unitary_getter is None and mixture_getter is None: raise TypeError( f"object of type '{type(val)}' has no _kraus_ or _mixture_ or _unitary_ method." diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 7ce4dd900fd..8f8b4ad2ade 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -22,6 +22,7 @@ import pytest import cirq +from cirq.protocols.apply_channel_protocol import _apply_kraus LOCAL_DEFAULT: list[np.ndarray] = [np.array([])] @@ -171,3 +172,74 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None: op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test')) assert cirq.has_kraus(op) assert not cirq.has_kraus(op, allow_decompose=False) + + +def test_strat_kraus_from_apply_channel_returns_none(): + # Remove _kraus_ and _apply_channel_ methods + class NoApplyChannelReset(cirq.ResetChannel): + def _kraus_(self): + return NotImplemented + + def _apply_channel_(self, args): + return NotImplemented + + gate_no_apply = NoApplyChannelReset() + with pytest.raises( + TypeError, + match="does have a _kraus_, _mixture_ or _unitary_ method, but it returned NotImplemented", + ): + cirq.kraus(gate_no_apply) + + +@pytest.mark.parametrize( + 'channel_cls,params', + [ + (cirq.BitFlipChannel, (0.5,)), + (cirq.PhaseFlipChannel, (0.3,)), + (cirq.DepolarizingChannel, (0.2,)), + (cirq.AmplitudeDampingChannel, (0.4,)), + (cirq.PhaseDampingChannel, (0.25,)), + ], +) +def test_kraus_fallback_to_apply_channel(channel_cls, params) -> None: + """Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_.""" + # Create the expected channel and get its Kraus operators + expected_channel = channel_cls(*params) + expected_kraus = cirq.kraus(expected_channel) + + class TestChannel: + def __init__(self, channel_cls, params): + self.channel_cls = channel_cls + self.params = params + self.expected_kraus = cirq.kraus(channel_cls(*params)) + + def _num_qubits_(self): + return 1 + + def _apply_channel_(self, args: cirq.ApplyChannelArgs): + return _apply_kraus(self.expected_kraus, args) + + chan = TestChannel(channel_cls, params) + kraus_ops = cirq.kraus(chan) + + # Compare the superoperator matrices for equivalence + expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) + actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) + np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) + + +def test_reset_channel_kraus_apply_channel_consistency(): + Reset = cirq.ResetChannel + # Original gate + gate = Reset() + cirq.testing.assert_has_consistent_apply_channel(gate) + cirq.testing.assert_consistent_channel(gate) + + # Remove _kraus_ method + class NoKrausReset(Reset): + def _kraus_(self): + return NotImplemented + + gate_no_kraus = NoKrausReset() + # Should still match the original superoperator + np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8) diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index 3f927a5579f..ea0870d0bfb 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -336,6 +336,7 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None atol: Absolute error tolerance. """ __tracebackhide__ = True + assert hasattr(val, '_apply_channel_') kraus = protocols.kraus(val, default=None) expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None