Skip to content

Adding tags to Moments #7467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Any,
Callable,
cast,
Hashable,
Iterable,
Iterator,
Mapping,
Expand Down Expand Up @@ -77,7 +78,12 @@ class Moment:
are no such operations, returns an empty Moment.
"""

def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> None:
def __init__(
self,
*contents: cirq.OP_TREE,
_flatten_contents: bool = True,
tags: tuple[Hashable, ...] = (),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a docstring for tags or is it intentionally hidden?

) -> None:
"""Constructs a moment with the given operations.

Args:
Expand Down Expand Up @@ -110,6 +116,7 @@ def __init__(self, *contents: cirq.OP_TREE, _flatten_contents: bool = True) -> N

self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None
self._control_keys: frozenset[cirq.MeasurementKey] | None = None
self._tags = tags

@classmethod
def from_ops(cls, *ops: cirq.Operation) -> cirq.Moment:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us add an optional tags argument to from_ops too.

Expand All @@ -133,6 +140,32 @@ def operations(self) -> tuple[cirq.Operation, ...]:
def qubits(self) -> frozenset[cirq.Qid]:
return frozenset(self._qubit_to_op)

@property
def tags(self) -> tuple[Hashable, ...]:
"""Returns a tuple of the operation's tags."""
return self._tags

def with_tags(self, *new_tags: Hashable) -> cirq.Moment:
"""Creates a new Moment with the current ops and the specified tags.

If the moment already has tags, this will add the new_tags to the
preexisting tags.

This method can be used to attach meta-data to moments
without affecting their functionality. The intended usage is to
attach classes intended for this purpose or strings to mark operations
for specific usage that will be recognized by consumers.

Tags can be a list of any type of object that is useful to identify
this operation as long as the type is hashable. If you wish the
resulting operation to be eventually serialized into JSON, you should
also restrict the operation to be JSON serializable.

Please note that tags should be instantiated if classes are
used. Raw types are not allowed.
"""
return Moment(*self._operations, _flatten_contents=False, tags=(*self._tags, *new_tags))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - consider adding a shortcut if not new_tags: return self


def operates_on_single_qubit(self, qubit: cirq.Qid) -> bool:
"""Determines if the moment has operations touching the given qubit.
Args:
Expand Down Expand Up @@ -183,7 +216,7 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment:
raise ValueError(f'Overlapping operations: {operation}')

# Use private variables to facilitate a quick copy.
m = Moment(_flatten_contents=False)
m = Moment(_flatten_contents=False, tags=self._tags)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description says tags are lost if the moment undergoes transformation, but here tags are retained. Either way is probably fine, but wanted to make sure this change was intentional.

Copy link
Collaborator

@pavoljuhas pavoljuhas Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 - perhaps it would be simpler to remove tags anytime we return a new Moment with changed operations. This should apply to the add / sub operators below as well. A way to think of it is that tags were given to a moment with some desired behavior / operations. If that changes, the tags should be dropped.

Currently there is an inconsistency that with_operation preserves tags, but without_operations_touching removes them.

Or is there some requirement that tags should be sticky and survive some Moment transformations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends on definition of "simpler". If Moment was a dataclass, which I think there's a decent argument that it should be, it'd be simpler to leave them.

That said, I think it's understandable if the with_ methods here leave them, but transformers remove them (which I think is what the PR description was actually referring to), since transformers reorganize the whole circuit in ways that moments are not preserved.

So, I don't have much of a preference whether the with_ methods here preserve tags or not (okay maybe a slight preference that they do preserve them--I think it's what users would expect, and the workaround for users to re-add tags after each with_ if we're not preserving them, is more troublesome than the workaround for the opposite, which they could do via moment.with_whatever(...).without_tags()). But either way is fine. The important thing is that whichever way we go, the methods on the Moment class itself should handle them consistently.

m._operations = self._operations + (operation,)
m._sorted_operations = None
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}
Expand Down Expand Up @@ -212,7 +245,7 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment:
if not flattened_contents:
return self

m = Moment(_flatten_contents=False)
m = Moment(_flatten_contents=False, tags=self._tags)
# Use private variables to facilitate a quick copy.
m._qubit_to_op = self._qubit_to_op.copy()
for op in flattened_contents:
Expand Down Expand Up @@ -510,18 +543,26 @@ def _superoperator_(self) -> np.ndarray:
return qis.kraus_to_superoperator(self._kraus_())

def _json_dict_(self) -> dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['operations'])
# For backwards compatibility, only output tags if they exist.
args = ['operations', 'tags'] if self._tags else ['operations']
return protocols.obj_to_dict_helper(self, args)

@classmethod
def _from_json_dict_(cls, operations, **kwargs):
return cls.from_ops(*operations)
def _from_json_dict_(cls, operations, tags=(), **kwargs):
return cls(*operations, tags=tags)

def __add__(self, other: cirq.OP_TREE) -> cirq.Moment:
if isinstance(other, circuit.AbstractCircuit):
return NotImplemented # Delegate to Circuit.__radd__.
if isinstance(other, Moment):
return self.with_tags(*other.tags).with_operations(other)
return self.with_operations(other)

def __sub__(self, other: cirq.OP_TREE) -> cirq.Moment:
if isinstance(other, Moment):
new_tags = tuple(tag for tag in self._tags if tag not in other.tags)
else:
new_tags = self._tags
must_remove = set(op_tree.flatten_to_ops(other))
new_ops = []
for op in self.operations:
Expand All @@ -535,7 +576,7 @@ def __sub__(self, other: cirq.OP_TREE) -> cirq.Moment:
f"Missing operations: {must_remove!r}\n"
f"Moment: {self!r}"
)
return Moment(new_ops)
return Moment(new_ops, tags=new_tags)

@overload
def __getitem__(self, key: raw_types.Qid) -> cirq.Operation:
Expand Down
92 changes: 92 additions & 0 deletions cirq-core/cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,3 +929,95 @@ def test_superoperator():
assert m._has_superoperator_()
s = m._superoperator_()
assert np.allclose(s, np.array([[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]]) / 2)


def test_moment_with_tags() -> None:
q0 = cirq.LineQubit(0)
q1 = cirq.LineQubit(1)
op1 = cirq.X(q0)
op2 = cirq.Y(q1)

# Test initialization with no tags
moment_no_tags = cirq.Moment(op1)
assert moment_no_tags.tags == ()

# Test initialization with tags
moment_with_tags = cirq.Moment(op1, op2, tags=("initial_tag_1", "initial_tag_2"))
assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2")

# Test with_tags method to add new tags
new_moment = moment_with_tags.with_tags("new_tag_1", "new_tag_2")

# Ensure the original moment's tags are unchanged
assert moment_with_tags.tags == ("initial_tag_1", "initial_tag_2")

# Ensure the new moment has both old and new tags
assert new_moment.tags == ("initial_tag_1", "initial_tag_2", "new_tag_1", "new_tag_2")

# Test with_tags on a moment that initially had no tags
new_moment_from_no_tags = moment_no_tags.with_tags("single_new_tag")
assert new_moment_from_no_tags.tags == ("single_new_tag",)

# Test adding no new tags
same_moment_tags = moment_with_tags.with_tags()
assert same_moment_tags.tags == ("initial_tag_1", "initial_tag_2")

class CustomTag:
"""Example Hashable Tag"""

def __init__(self, value):
self.value = value

def __hash__(self):
return hash(self.value) # pragma: nocover

def __eq__(self, other):
return isinstance(other, CustomTag) and self.value == other.value # pragma: nocover

def __repr__(self):
return f"CustomTag({self.value})" # pragma: nocover

tag_obj = CustomTag("complex_tag")
moment_with_custom_tag = cirq.Moment(op1, tags=("string_tag", 123, tag_obj))
assert moment_with_custom_tag.tags == ("string_tag", 123, tag_obj)

new_moment_with_custom_tag = moment_with_custom_tag.with_tags(456)
assert new_moment_with_custom_tag.tags == ("string_tag", 123, tag_obj, 456)


def test_adding_moments_with_tags() -> None:
q0, q1, q2 = cirq.LineQubit.range(3)
op1 = cirq.X(q0)
op2 = cirq.Y(q1)
op3 = cirq.Z(q2)

moment_a = cirq.Moment(op1, tags=("tag_a1", "tag_a2"))
moment_b = cirq.Moment(op2, op3, tags=("tag_b1", "tag_b2"))

combined_moment = moment_a + moment_b
assert combined_moment.operations == (op1, op2, op3)
assert combined_moment.tags == ("tag_a1", "tag_a2", "tag_b1", "tag_b2")

# Test adding a moment to a moment with no tags
moment_c = cirq.Moment(op1)
moment_d = cirq.Moment(op2, tags=("tag_d1",))
combined_no_tags = moment_c + moment_d
assert combined_no_tags.tags == ("tag_d1",)

# Test adding a moment with no tags to a moment
moment_e = cirq.Moment(op1, tags=("tag_e1",))
moment_f = cirq.Moment(op2)
combined_no_tags_reversed = moment_e + moment_f
assert combined_no_tags_reversed.tags == ("tag_e1",)


def test_subtracting_moments_with_tags() -> None:
q0, q1, q2 = cirq.LineQubit.range(3)
op1 = cirq.X(q0)
op2 = cirq.Y(q1)
moment_a = cirq.Moment(op1, op2, tags=("tag_a1", "tag_a2"))
moment_b = cirq.Moment(op2, tags=("tag_a1", "tag_b2"))

subtracted_moment = moment_a - moment_b
assert subtracted_moment.operations == (op1,)
assert subtracted_moment.tags == ("tag_a2",)
23 changes: 22 additions & 1 deletion cirq-core/cirq/protocols/json_test_data/Moment.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,26 @@
}
}
]
},
{
"cirq_type": "Moment",
"operations": [
{
"cirq_type": "SingleQubitPauliStringGateOperation",
"pauli": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"qubit": {
"cirq_type": "LineQubit",
"x": 0
}
}
],
"tags": {
"cirq_type": "Duration",
"picos": 25000
}
}
]
]
7 changes: 6 additions & 1 deletion cirq-core/cirq/protocols/json_test_data/Moment.repr
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@
cirq.X(cirq.LineQubit(0)),
cirq.Y(cirq.LineQubit(1)),
cirq.Z(cirq.LineQubit(2)),
)]
),
cirq.Moment(
cirq.X(cirq.LineQubit(0)),
tags=cirq.Duration(nanos=25)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not fit the declared tuple[Hashable] type.
Please fix here and in the example json.

)
]