Skip to content

Fix #853: improve formatter for pretty-printing musical score diagrams #1658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion qualtran/bloqs/rotations/quantum_variable_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,17 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
else self.cost_reg.total_bits()
)
eps = self.eps / num_rotations

if self.cost_dtype.signed:
out[0] = bb.add(ZPowGate(exponent=1, eps=eps), q=out[0])

offset = 1 + self.cost_dtype.num_frac - self.num_frac_rotations
for i in range(num_rotations):
power_of_two = i - self.num_frac_rotations
power_of_two = i - self.cost_dtype.num_frac
exp = (2**power_of_two) * self.gamma * 2
offset = (
offset if isinstance(offset, int) else 1
) # offset -> 1 if not int to avoid indexing errors
out[-(i + offset)] = bb.add(ZPowGate(exponent=exp, eps=eps), q=out[-(i + offset)])
return {self.cost_reg.name: bb.join(out, self.cost_dtype)}

Expand Down
3 changes: 2 additions & 1 deletion qualtran/drawing/_show_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def show_bloq(bloq: 'Bloq', type: str = 'graph'): # pylint: disable=redefined-b
elif type.lower() == 'dtype':
IPython.display.display(TypedGraphDrawer(bloq).get_svg())
elif type.lower() == 'musical_score':
draw_musical_score(get_musical_score_data(bloq))
msd = get_musical_score_data(bloq)
draw_musical_score(msd, pretty_print=False)
elif type.lower() == 'latex':
show_bloq_via_qpic(bloq)
else:
Expand Down
88 changes: 87 additions & 1 deletion qualtran/drawing/musical_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import abc
import heapq
import json
import re
from enum import Enum
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Set, Tuple, Union

import attrs
import networkx as nx
import numpy as np
import sympy
from attrs import frozen, mutable
from matplotlib import pyplot as plt
from numpy.typing import NDArray
Expand Down Expand Up @@ -688,7 +690,9 @@ def draw_musical_score(
unit_to_inches: float = 0.8,
max_width: float = 10.0,
max_height: float = 8.0,
pretty_print: bool = False,
):

# First, set up data coordinate limits and figure size.
# X coordinates go from -1 to max_x
# with 1 unit of padding it goes from -2 to max_x+1
Expand Down Expand Up @@ -727,7 +731,18 @@ def draw_musical_score(
vline.label.draw(ax, vline.x, vline.bottom_y - 0.5)

for soq in msd.soqs:
symb = soq.symb
new_soq = soq
if pretty_print and isinstance(soq.symb, TextBox):
# Beautify text items if pretty_print is enabled and is TextBox
try:
pretty_text = pretty_format_soq_text(soq.symb.text)
except (ValueError, NameError, TypeError, sympy.SympifyError) as e:
pretty_text = soq.symb.text

# Build new soq
new_soq = soq.__class__(symb=TextBox(text=pretty_text), rpos=soq.rpos, ident=soq.ident)

symb = new_soq.symb
symb.draw(ax, soq.rpos.seq_x, soq.rpos.y)

ax.set_xlim(xlim)
Expand All @@ -750,3 +765,74 @@ def default(self, o: Any) -> Any:
def dump_musical_score(msd: MusicalScoreData, name: str):
with open(f'{name}.json', 'w') as f:
json.dump(msd, f, indent=2, cls=MusicalScoreEncoder)


def pretty_format_soq_text(soq_text: str) -> str:
"""
Evaluates a single soq.symb.text item returning prettiest expression possible or original text if beautification fails.

Args:
soq_text: A raw soq text (soq.symb.text)

Returns:
pretty_soq_text: A pretty soq text

"""

def symbols_in_soq(raw_text: str) -> str:
"""
Identifies the symbols in the soq text based on existence of an substring "(Y)".
Limitations:
Currently identifies single-character symbols A-Z and a-z.

Args:
raw_text: A raw soq string.

Returns:
symbol: A string containing only the symbol as a string or an empty string if no symbol is found

"""

# The pattern searches for "(Y)" (with any character), targeting Abs() section of expressions
pattern = r"\(([a-zA-Z])\)"
match = re.search(pattern, raw_text)

# If no such section is found, the function returns empty string that is used elsewhere as safety check
symbol = match.group(1) if match else ""

return symbol

# Identify any symbols in soq text.
symbol = symbols_in_soq(soq_text)

# If no symbol found: return original soq_text.
# Note. Careful if removing.
# Block is also a security check.
# User can change symbol that reaches this function.
# Simpify is vulnerable to string injection, so only valid symbols should be allowed through.
if symbol == "":
return soq_text

# Simpify locals. They enhance sympify's own ability to evaluate match strings.
# Note. Abs(Y) expressions evaluated as Abs(np.pi) to enable full evaluation of string.
# It appears only very small angles in that part of the expression would make a difference.
simpify_locals = {
"Min": sympy.Min,
"ceiling": sympy.ceiling,
"log2": lambda y: sympy.log(y, 2),
"Abs": lambda y: sympy.Abs(np.pi),
"Y": sympy.Symbol("Y"),
}

# Evaluation
# Eval 1. Get the gate out of the way
gate, expression = soq_text.split("^", 1)

# Eval 2. Evaluate mathematical expression using sympify (enhanced with locals) or default to original soq_text
try:
expression_sympified = sympy.sympify(expression, locals=simpify_locals, evaluate=True)
pretty_text = str(gate) + "^" + str(expression_sympified)
except (sympy.SympifyError, ValueError, NameError, TypeError) as e:
pretty_text = soq_text

return pretty_text
10 changes: 10 additions & 0 deletions qualtran/drawing/musical_score_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from qualtran.bloqs.mcmt import MultiAnd
from qualtran.drawing import dump_musical_score, get_musical_score_data, HLine
from qualtran.drawing.musical_score import pretty_format_soq_text
from qualtran.testing import execute_notebook


Expand All @@ -32,6 +33,15 @@ def test_dump_json(tmp_path):
dump_musical_score(msd, name=f'{tmp_path}/musical_score_example')


def test_pretty_format_soq_text():
soq_text = "Z^2*2**(11 - Min(6, ceiling(log2(6283185307.17959*Abs(Y)))))*Y"
expected_str = "Z^64*Y"
assert expected_str == pretty_format_soq_text(soq_text)
soq_text = "Z^2*Y/2**Min(6, ceiling(log2(6283185307.17959*Abs(Y))))"
expected_str = "Z^Y/32"
assert expected_str == pretty_format_soq_text(soq_text)


@pytest.mark.notebook
def test_notebook():
execute_notebook('musical_score')