diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 51e669b3aea..f9b2bb5da02 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -24,6 +24,7 @@ Any, Callable, cast, + Hashable, Iterable, Iterator, Mapping, @@ -77,7 +78,12 @@ class Moment: are no such operations, returns an empty Moment. """ - def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> None: + def __init__( + self, + *contents: cirq.OP_TREE, + _flatten_contents: bool = True, + tags: tuple[Hashable, ...] = (), + ) -> None: """Constructs a moment with the given operations. Args: @@ -110,6 +116,7 @@ def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> N self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None self._control_keys: frozenset[cirq.MeasurementKey] | None = None + self._tags = tags @classmethod def from_ops(cls, *ops: cirq.Operation) -> cirq.Moment: @@ -133,6 +140,32 @@ def operations(self) -> tuple[cirq.Operation, ...]: def qubits(self) -> frozenset[cirq.Qid]: return frozenset(self._qubit_to_op) + @property + def tags(self) -> tuple[Hashable, ...]: + """Returns a tuple of the operation's tags.""" + return self._tags + + def with_tags(self, *new_tags: Hashable) -> cirq.Moment: + """Creates a new Moment with the current ops and the specified tags. + + If the moment already has tags, this will add the new_tags to the + preexisting tags. + + This method can be used to attach meta-data to moments + without affecting their functionality. The intended usage is to + attach classes intended for this purpose or strings to mark operations + for specific usage that will be recognized by consumers. + + Tags can be a list of any type of object that is useful to identify + this operation as long as the type is hashable. If you wish the + resulting operation to be eventually serialized into JSON, you should + also restrict the operation to be JSON serializable. + + Please note that tags should be instantiated if classes are + used. Raw types are not allowed. + """ + return Moment(*self._operations, _flatten_contents=False, tags=(*self._tags, *new_tags)) + def operates_on_single_qubit(self, qubit: cirq.Qid) -> bool: """Determines if the moment has operations touching the given qubit. Args: @@ -183,7 +216,7 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment: raise ValueError(f'Overlapping operations: {operation}') # Use private variables to facilitate a quick copy. - m = Moment(_flatten_contents=False) + m = Moment(_flatten_contents=False, tags=self._tags) m._operations = self._operations + (operation,) m._sorted_operations = None m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}} @@ -212,7 +245,7 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment: if not flattened_contents: return self - m = Moment(_flatten_contents=False) + m = Moment(_flatten_contents=False, tags=self._tags) # Use private variables to facilitate a quick copy. m._qubit_to_op = self._qubit_to_op.copy() for op in flattened_contents: @@ -510,18 +543,26 @@ def _superoperator_(self) -> np.ndarray: return qis.kraus_to_superoperator(self._kraus_()) def _json_dict_(self) -> dict[str, Any]: - return protocols.obj_to_dict_helper(self, ['operations']) + # For backwards compatibility, only output tags if they exist. + args = ['operations', 'tags'] if self._tags else ['operations'] + return protocols.obj_to_dict_helper(self, args) @classmethod - def _from_json_dict_(cls, operations, **kwargs): - return cls.from_ops(*operations) + def _from_json_dict_(cls, operations, tags=(), **kwargs): + return cls(*operations, tags=tags) def __add__(self, other: cirq.OP_TREE) -> cirq.Moment: if isinstance(other, circuit.AbstractCircuit): return NotImplemented # Delegate to Circuit.__radd__. + if isinstance(other, Moment): + return self.with_tags(*other.tags).with_operations(other) return self.with_operations(other) def __sub__(self, other: cirq.OP_TREE) -> cirq.Moment: + if isinstance(other, Moment): + new_tags = tuple(tag for tag in self._tags if tag not in other.tags) + else: + new_tags = self._tags must_remove = set(op_tree.flatten_to_ops(other)) new_ops = [] for op in self.operations: @@ -535,7 +576,7 @@ def __sub__(self, other: cirq.OP_TREE) -> cirq.Moment: f"Missing operations: {must_remove!r}\n" f"Moment: {self!r}" ) - return Moment(new_ops) + return Moment(new_ops, tags=new_tags) @overload def __getitem__(self, key: raw_types.Qid) -> cirq.Operation: diff --git a/cirq-core/cirq/circuits/moment_test.py b/cirq-core/cirq/circuits/moment_test.py index cd0eb856c88..55e29fbbd69 100644 --- a/cirq-core/cirq/circuits/moment_test.py +++ b/cirq-core/cirq/circuits/moment_test.py @@ -929,3 +929,95 @@ def test_superoperator(): assert m._has_superoperator_() s = m._superoperator_() assert np.allclose(s, np.array([[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]]) / 2) + + +def test_moment_with_tags() -> None: + q0 = cirq.LineQubit(0) + q1 = cirq.LineQubit(1) + op1 = cirq.X(q0) + op2 = cirq.Y(q1) + + # Test initialization with no tags + moment_no_tags = cirq.Moment(op1) + assert moment_no_tags.tags == () + + # Test initialization with tags + moment_with_tags = cirq.Moment(op1, op2, tags=("initial_tag_1", "initial_tag_2")) + assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2") + + # Test with_tags method to add new tags + new_moment = moment_with_tags.with_tags("new_tag_1", "new_tag_2") + + # Ensure the original moment's tags are unchanged + assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2") + + # Ensure the new moment has both old and new tags + assert new_moment.tags == ("initial_tag_1", "initial_tag_2", "new_tag_1", "new_tag_2") + + # Test with_tags on a moment that initially had no tags + new_moment_from_no_tags = moment_no_tags.with_tags("single_new_tag") + assert new_moment_from_no_tags.tags == ("single_new_tag",) + + # Test adding no new tags + same_moment_tags = moment_with_tags.with_tags() + assert same_moment_tags.tags == ("initial_tag_1", "initial_tag_2") + + class CustomTag: + """Example Hashable Tag""" + + def __init__(self, value): + self.value = value + + def __hash__(self): + return hash(self.value) # pragma: nocover + + def __eq__(self, other): + return isinstance(other, CustomTag) and self.value == other.value # pragma: nocover + + def __repr__(self): + return f"CustomTag({self.value})" # pragma: nocover + + tag_obj = CustomTag("complex_tag") + moment_with_custom_tag = cirq.Moment(op1, tags=("string_tag", 123, tag_obj)) + assert moment_with_custom_tag.tags == ("string_tag", 123, tag_obj) + + new_moment_with_custom_tag = moment_with_custom_tag.with_tags(456) + assert new_moment_with_custom_tag.tags == ("string_tag", 123, tag_obj, 456) + + +def test_adding_moments_with_tags() -> None: + q0, q1, q2 = cirq.LineQubit.range(3) + op1 = cirq.X(q0) + op2 = cirq.Y(q1) + op3 = cirq.Z(q2) + + moment_a = cirq.Moment(op1, tags=("tag_a1", "tag_a2")) + moment_b = cirq.Moment(op2, op3, tags=("tag_b1", "tag_b2")) + + combined_moment = moment_a + moment_b + assert combined_moment.operations == (op1, op2, op3) + assert combined_moment.tags == ("tag_a1", "tag_a2", "tag_b1", "tag_b2") + + # Test adding a moment to a moment with no tags + moment_c = cirq.Moment(op1) + moment_d = cirq.Moment(op2, tags=("tag_d1",)) + combined_no_tags = moment_c + moment_d + assert combined_no_tags.tags == ("tag_d1",) + + # Test adding a moment with no tags to a moment + moment_e = cirq.Moment(op1, tags=("tag_e1",)) + moment_f = cirq.Moment(op2) + combined_no_tags_reversed = moment_e + moment_f + assert combined_no_tags_reversed.tags == ("tag_e1",) + + +def test_subtracting_moments_with_tags() -> None: + q0, q1, q2 = cirq.LineQubit.range(3) + op1 = cirq.X(q0) + op2 = cirq.Y(q1) + moment_a = cirq.Moment(op1, op2, tags=("tag_a1", "tag_a2")) + moment_b = cirq.Moment(op2, tags=("tag_a1", "tag_b2")) + + subtracted_moment = moment_a - moment_b + assert subtracted_moment.operations == (op1,) + assert subtracted_moment.tags == ("tag_a2",) diff --git a/cirq-core/cirq/protocols/json_test_data/Moment.json b/cirq-core/cirq/protocols/json_test_data/Moment.json index f4a4d6d8f23..768d04d7504 100644 --- a/cirq-core/cirq/protocols/json_test_data/Moment.json +++ b/cirq-core/cirq/protocols/json_test_data/Moment.json @@ -39,5 +39,26 @@ } } ] + }, + { + "cirq_type": "Moment", + "operations": [ + { + "cirq_type": "SingleQubitPauliStringGateOperation", + "pauli": { + "cirq_type": "_PauliX", + "exponent": 1.0, + "global_shift": 0.0 + }, + "qubit": { + "cirq_type": "LineQubit", + "x": 0 + } + } + ], + "tags": { + "cirq_type": "Duration", + "picos": 25000 + } } -] \ No newline at end of file +] diff --git a/cirq-core/cirq/protocols/json_test_data/Moment.repr b/cirq-core/cirq/protocols/json_test_data/Moment.repr index 7629c2aa8eb..8c0680e4bb9 100644 --- a/cirq-core/cirq/protocols/json_test_data/Moment.repr +++ b/cirq-core/cirq/protocols/json_test_data/Moment.repr @@ -2,4 +2,9 @@ cirq.X(cirq.LineQubit(0)), cirq.Y(cirq.LineQubit(1)), cirq.Z(cirq.LineQubit(2)), -)] \ No newline at end of file +), +cirq.Moment( + cirq.X(cirq.LineQubit(0)), + tags=cirq.Duration(nanos=25) +) +]