Skip to content

Derive _kraus_ from _apply_channel_ #7434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1ea89fe
add changes to apply_unitary and unitary to be able to take in numpy …
iamsusiep Jun 12, 2025
f0f8e0e
Merge branch 'main' into protocol
iamsusiep Jun 12, 2025
3ec6d67
add changes to apply_unitary and unitary to be able to take in numpy …
iamsusiep Jun 12, 2025
a4d0435
Merge branch 'protocol' of https://github.com/iamsusiep/Cirq into pro…
iamsusiep Jun 12, 2025
381f105
handle nd array in has_unitary and address pr comment
iamsusiep Jun 12, 2025
5ed34ef
fix related tests
iamsusiep Jun 12, 2025
2044d0a
lint formatting etc
iamsusiep Jun 12, 2025
055245c
add coverage
iamsusiep Jun 12, 2025
175c719
fix test case
iamsusiep Jun 12, 2025
d4d5eae
Merge remote-tracking branch 'upstream/main' into protocol
iamsusiep Jun 18, 2025
1ddd07c
update sycamore test based on pr comment
iamsusiep Jun 18, 2025
021f09b
update kraus protocol to add apply_channel fallback
iamsusiep Jun 24, 2025
dec375d
Merge remote-tracking branch 'upstream/main' into protocol
iamsusiep Jun 24, 2025
fdcd595
.
iamsusiep Jun 24, 2025
3bf33eb
.
iamsusiep Jun 24, 2025
69c2014
.
iamsusiep Jun 24, 2025
6e15904
.
iamsusiep Jun 24, 2025
3a1f45e
.
iamsusiep Jun 24, 2025
c09d37c
.
iamsusiep Jun 24, 2025
db5f020
.
iamsusiep Jun 24, 2025
d616813
reuse applying kraus operator
iamsusiep Jun 25, 2025
c132b3f
.
iamsusiep Jun 25, 2025
97bc793
.
iamsusiep Jun 25, 2025
c1374b3
Merge branch 'protocol' of https://github.com/iamsusiep/Cirq into pro…
iamsusiep Jun 25, 2025
c30ae4e
.
iamsusiep Jun 25, 2025
84224d8
Merge branch 'main' into protocol
iamsusiep Jun 25, 2025
c88b568
.
iamsusiep Jun 26, 2025
0840625
Merge branch 'protocol' of https://github.com/iamsusiep/Cirq into pro…
iamsusiep Jun 26, 2025
6601fe0
Merge remote-tracking branch 'upstream/main' into protocol
iamsusiep Jul 3, 2025
4a65fd5
.
iamsusiep Jul 3, 2025
c8885c2
Merge branch 'main' into protocol
iamsusiep Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cirq-core/cirq/protocols/kraus_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
from typing_extensions import Protocol

from cirq import protocols, qis
from cirq._doc import doc_private
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.protocols.mixture_protocol import has_mixture
Expand Down Expand Up @@ -94,6 +95,35 @@ def _has_kraus_(self) -> bool:
"""


def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
"""Attempts to compute a value's Kraus operators via its _apply_channel_ method.
This is very expensive (O(16^N)), so only do this as a last resort."""
method = getattr(val, '_apply_channel_', None)
if method is None:
return None

qid_shape = protocols.qid_shape(val)

eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128)
superop = protocols.apply_channel(
val=val,
args=protocols.ApplyChannelArgs(
target_tensor=eye,
out_buffer=np.ones_like(eye) * float('nan'),
auxiliary_buffer0=np.ones_like(eye) * float('nan'),
auxiliary_buffer1=np.ones_like(eye) * float('nan'),
left_axes=list(range(len(qid_shape))),
right_axes=list(range(len(qid_shape), len(qid_shape) * 2)),
),
default=None,
)
if superop is None or superop is NotImplemented:
return None
n = np.prod(qid_shape) ** 2
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)))
return tuple(kraus_ops)


def kraus(
val: Any, default: Any = RaiseTypeErrorIfNotProvided
) -> tuple[np.ndarray, ...] | TDefault:
Expand Down Expand Up @@ -162,6 +192,11 @@ def kraus(
if default is not RaiseTypeErrorIfNotProvided:
return default

# Last-resort fallback: try to derive Kraus from _apply_channel_
result = _strat_kraus_from_apply_channel(val)
if result is not None:
return result

if kraus_getter is None and unitary_getter is None and mixture_getter is None:
raise TypeError(
f"object of type '{type(val)}' has no _kraus_ or _mixture_ or _unitary_ method."
Expand Down
72 changes: 72 additions & 0 deletions cirq-core/cirq/protocols/kraus_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import cirq
from cirq.protocols.apply_channel_protocol import _apply_kraus

LOCAL_DEFAULT: list[np.ndarray] = [np.array([])]

Expand Down Expand Up @@ -171,3 +172,74 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None:
op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test'))
assert cirq.has_kraus(op)
assert not cirq.has_kraus(op, allow_decompose=False)


def test_strat_kraus_from_apply_channel_returns_none():
# Remove _kraus_ and _apply_channel_ methods
class NoApplyChannelReset(cirq.ResetChannel):
def _kraus_(self):
return NotImplemented

def _apply_channel_(self, args):
return NotImplemented

gate_no_apply = NoApplyChannelReset()
with pytest.raises(
TypeError,
match="does have a _kraus_, _mixture_ or _unitary_ method, but it returned NotImplemented",
):
cirq.kraus(gate_no_apply)


@pytest.mark.parametrize(
'channel_cls,params',
[
(cirq.BitFlipChannel, (0.5,)),
(cirq.PhaseFlipChannel, (0.3,)),
(cirq.DepolarizingChannel, (0.2,)),
(cirq.AmplitudeDampingChannel, (0.4,)),
(cirq.PhaseDampingChannel, (0.25,)),
],
)
def test_kraus_fallback_to_apply_channel(channel_cls, params) -> None:
"""Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_."""
# Create the expected channel and get its Kraus operators
expected_channel = channel_cls(*params)
expected_kraus = cirq.kraus(expected_channel)

class TestChannel:
def __init__(self, channel_cls, params):
self.channel_cls = channel_cls
self.params = params
self.expected_kraus = cirq.kraus(channel_cls(*params))

def _num_qubits_(self):
return 1

def _apply_channel_(self, args: cirq.ApplyChannelArgs):
return _apply_kraus(self.expected_kraus, args)

chan = TestChannel(channel_cls, params)
kraus_ops = cirq.kraus(chan)

# Compare the superoperator matrices for equivalence
expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus)
actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops)
np.testing.assert_allclose(actual_super, expected_super, atol=1e-8)


def test_reset_channel_kraus_apply_channel_consistency():
Reset = cirq.ResetChannel
# Original gate
gate = Reset()
cirq.testing.assert_has_consistent_apply_channel(gate)
cirq.testing.assert_consistent_channel(gate)

# Remove _kraus_ method
class NoKrausReset(Reset):
def _kraus_(self):
return NotImplemented

gate_no_kraus = NoKrausReset()
# Should still match the original superoperator
np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8)
1 change: 1 addition & 0 deletions cirq-core/cirq/testing/circuit_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None
atol: Absolute error tolerance.
"""
__tracebackhide__ = True
assert hasattr(val, '_apply_channel_')

kraus = protocols.kraus(val, default=None)
expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None
Expand Down