diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 1f0301160de..abaff46a6ef 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -125,10 +125,7 @@ def with_qubits(self, *new_qubits): *self._conditions ) - def _decompose_(self): - return self._decompose_with_context_() - - def _decompose_with_context_(self, context: cirq.DecompositionContext | None = None): + def _decompose_with_context_(self, *, context: cirq.DecompositionContext): result = protocols.decompose_once( self._sub_operation, NotImplemented, flatten=False, context=context ) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 4ed6a03b876..f57a57b60b7 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -139,11 +139,8 @@ def num_controls(self) -> int: def _qid_shape_(self) -> tuple[int, ...]: return self.control_qid_shape + protocols.qid_shape(self.sub_gate) - def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE: - return self._decompose_with_context_(qubits) - def _decompose_with_context_( - self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None + self, qubits: tuple[cirq.Qid, ...], *, context: cirq.DecompositionContext ) -> None | NotImplementedType | cirq.OP_TREE: control_qubits = list(qubits[: self.num_controls()]) controlled_sub_gate = self.sub_gate.controlled( diff --git a/cirq-core/cirq/ops/controlled_operation.py b/cirq-core/cirq/ops/controlled_operation.py index b133b5c451a..e1fb0b604a6 100644 --- a/cirq-core/cirq/ops/controlled_operation.py +++ b/cirq-core/cirq/ops/controlled_operation.py @@ -134,10 +134,7 @@ def with_qubits(self, *new_qubits): new_qubits[:n], self.sub_operation.with_qubits(*new_qubits[n:]), self.control_values ) - def _decompose_(self): - return self._decompose_with_context_() - - def _decompose_with_context_(self, context: cirq.DecompositionContext | None = None): + def _decompose_with_context_(self, *, context: cirq.DecompositionContext): result = protocols.decompose_once_with_qubits( self.gate, self.qubits, NotImplemented, flatten=False, context=context ) diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 925bf262ed2..eceee35d5ac 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -148,11 +148,8 @@ def _qid_shape_(self): def _num_qubits_(self): return len(self._qubits) - def _decompose_(self) -> cirq.OP_TREE: - return self._decompose_with_context_() - def _decompose_with_context_( - self, context: cirq.DecompositionContext | None = None + self, *, context: cirq.DecompositionContext ) -> cirq.OP_TREE: return protocols.decompose_once_with_qubits( self.gate, self.qubits, NotImplemented, flatten=False, context=context diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index c84a6928f02..3b7107d315d 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -828,11 +828,8 @@ def _from_json_dict_(cls, sub_operation, tags, **kwargs): def _json_dict_(self) -> dict[str, Any]: return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags']) - def _decompose_(self) -> cirq.OP_TREE: - return self._decompose_with_context_() - def _decompose_with_context_( - self, context: cirq.DecompositionContext | None = None + self, *, context: cirq.DecompositionContext ) -> cirq.OP_TREE: return protocols.decompose_once( self.sub_operation, default=None, flatten=False, context=context @@ -986,7 +983,7 @@ def _decompose_(self, qubits): return self._decompose_with_context_(qubits) def _decompose_with_context_( - self, qubits: Sequence[cirq.Qid], context: cirq.DecompositionContext | None = None + self, qubits: Sequence[cirq.Qid], *, context: cirq.DecompositionContext ) -> cirq.OP_TREE: return protocols.inverse( protocols.decompose_once_with_qubits(self._original, qubits, context=context) diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index 3d552df311f..f8ea151f420 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -54,7 +54,7 @@ @runtime_checkable class OpDecomposerWithContext(Protocol): def __call__( - self, __op: cirq.Operation, *, context: cirq.DecompositionContext | None = None + self, __op: cirq.Operation, *, context: cirq.DecompositionContext ) -> DecomposeResult: ... @@ -126,7 +126,7 @@ def _decompose_(self) -> DecomposeResult: pass def _decompose_with_context_( - self, *, context: DecompositionContext | None = None + self, *, context: DecompositionContext ) -> DecomposeResult: pass @@ -154,13 +154,13 @@ def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> DecomposeResult: pass def _decompose_with_context_( - self, qubits: tuple[cirq.Qid, ...], *, context: DecompositionContext | None = None + self, qubits: tuple[cirq.Qid, ...], *, context: DecompositionContext ) -> DecomposeResult: pass def _try_op_decomposer( - val: Any, decomposer: OpDecomposer | None, *, context: DecompositionContext | None = None + val: Any, decomposer: OpDecomposer | None, *, context: DecompositionContext ) -> DecomposeResult: if decomposer is None or not isinstance(val, ops.Operation): return None @@ -173,7 +173,7 @@ def _try_op_decomposer( @dataclasses.dataclass(frozen=True) class _DecomposeArgs: - context: DecompositionContext | None + context: DecompositionContext intercepting_decomposer: OpDecomposer | None fallback_decomposer: OpDecomposer | None keep: Callable[[cirq.Operation], bool] | None @@ -362,14 +362,9 @@ def decompose_once( TypeError: `val` didn't have a `_decompose_` method (or that method returned `NotImplemented` or `None`) and `default` wasn't set. """ - if context is None: - context = DecompositionContext( - ops.SimpleQubitManager(prefix=f'_decompose_protocol_{next(_CONTEXT_COUNTER)}') - ) - - method = getattr(val, '_decompose_with_context_', None) - decomposed = NotImplemented if method is None else method(*args, **kwargs, context=context) - if decomposed is NotImplemented or decomposed is None: + if context is not None and hasattr(val, '_decompose_with_context_'): + decomposed = val._decompose_with_context_(*args, context=context, **kwargs) + else: method = getattr(val, '_decompose_', None) decomposed = NotImplemented if method is None else method(*args, **kwargs)