diff --git a/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py index fda33f5fc..1474b6e4c 100644 --- a/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py @@ -33,14 +33,19 @@ SoquetT, ) from qualtran._infra.registers import Side +from qualtran.bloqs.arithmetic import Add +from qualtran.bloqs.arithmetic.subtraction import SubtractFrom +from qualtran.bloqs.basic_gates.swap import Swap from qualtran.bloqs.basic_gates.z_basis import IntState +from qualtran.bloqs.data_loading.qroam_clean import QROAMClean from qualtran.bloqs.mod_arithmetic import CModMulK +from qualtran.bloqs.mod_arithmetic.mod_addition import ModAdd +from qualtran.bloqs.mod_arithmetic.mod_subtraction import ModSub from qualtran.drawing import Text, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join from qualtran.simulation.classical_sim import ClassicalValT -from qualtran.symbolics import is_symbolic -from qualtran.symbolics.types import SymbolicInt +from qualtran.symbolics import is_symbolic, Shaped, SymbolicInt @frozen @@ -54,10 +59,13 @@ class ModExp(Bloq): This bloq decomposes into controlled modular exponentiation for each exponent bit. Args: - base: The integer base of the exponentiation - mod: The integer modulus - exp_bitsize: The size of the `exponent` thru-register - x_bitsize: The size of the `x` right-register + base: The integer base of the exponentiation. + mod: The integer modulus. + exp_bitsize: The size of the `exponent` thru-register. + x_bitsize: The size of the `x` right-register. + exp_window_size: The window size of windowed arithmetic on the controlled modular + multiplications. + mult_window_size: The window size of windowed arithmetic on the modular product additions. Registers: exponent: The exponent @@ -66,12 +74,20 @@ class ModExp(Bloq): References: [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and EkerÄ. 2019. + + [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). + Stephane Beauregard. 2003. + + [Windowed quantum arithmetic](https://arxiv.org/abs/1905.07682). + Craig Gidney. 2019. """ base: 'SymbolicInt' mod: 'SymbolicInt' exp_bitsize: 'SymbolicInt' x_bitsize: 'SymbolicInt' + exp_window_size: Optional['SymbolicInt'] = None + mult_window_size: Optional['SymbolicInt'] = None def __attrs_post_init__(self): if not is_symbolic(self.base, self.mod): @@ -87,12 +103,7 @@ def signature(self) -> 'Signature': ) @classmethod - def make_for_shor( - cls, - big_n: 'SymbolicInt', - g: Optional['SymbolicInt'] = None, - rs: Optional[np.random.RandomState] = None, - ): + def make_for_shor(cls, big_n: 'SymbolicInt', g: Optional['SymbolicInt'] = None, exp_window_size: Optional['SymbolicInt'] = None, mult_window_size: Optional['SymbolicInt'] = None, rs: Optional[np.random.RandomState] = None): """Factory method that sets up the modular exponentiation for a factoring run. Args: @@ -115,7 +126,29 @@ def make_for_shor( g = rs.randint(2, int(big_n)) if math.gcd(g, int(big_n)) == 1: break - return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n) + return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n, exp_window_size=exp_window_size, mult_window_size=mult_window_size) + + def qrom(self, data): + if is_symbolic(self.exp_bitsize) or is_symbolic(self.exp_window_size): + log_block_sizes = None + if is_symbolic(self.exp_bitsize) and not is_symbolic(self.exp_window_size): + # We assume that bitsize is much larger than window_size + log_block_sizes = (0,) + return QROAMClean( + [ + data, + ], + selection_bitsizes=(self.exp_window_size, self.mult_window_size), + target_bitsizes=(self.x_bitsize,), + log_block_sizes=log_block_sizes, + ) + + return QROAMClean( + [data], + selection_bitsizes=(self.exp_window_size, self.mult_window_size), + target_bitsizes=(self.x_bitsize,), + ) + def _CtrlModMul(self, k: 'SymbolicInt'): """Helper method to return a `CModMulK` with attributes forwarded.""" @@ -123,21 +156,102 @@ def _CtrlModMul(self, k: 'SymbolicInt'): def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']: if is_symbolic(self.exp_bitsize): - raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.") - # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.") x = bb.add(IntState(val=1, bitsize=self.x_bitsize)) exponent = bb.split(exponent) - base = self.base % self.mod - for j in range(self.exp_bitsize - 1, 0 - 1, -1): - exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) - base = (base * base) % self.mod + if self.exp_window_size is not None and self.mult_window_size is not None: + k = self.base + + a = bb.split(x) + b = bb.add(IntState(val=0, bitsize=self.x_bitsize)) + + ei = np.split(np.array(exponent), self.exp_bitsize // self.exp_window_size) + for i in range(self.exp_bitsize // self.exp_window_size): + kes = [pow(k, 2**i * x_e, self.mod) for x_e in range(2**self.exp_window_size)] + kes_inv = [pow(x_e, -1, self.mod) for x_e in kes] + + mi = np.split(np.array(a), self.x_bitsize // self.mult_window_size) + for j in range(self.x_bitsize // self.mult_window_size): + data = list([(ke * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke in kes) + ei_i = bb.join(ei[(self.exp_bitsize // self.exp_window_size) - i - 1], QUInt((self.exp_window_size))) + mi_i = bb.join(mi[(self.x_bitsize // self.mult_window_size) - j - 1], QUInt((self.mult_window_size))) + ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i) + t, b = bb.add(ModAdd(self.x_bitsize, self.mod), x=t, y=b) + junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))} + ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping) + ei[(self.exp_bitsize // self.exp_window_size) - i - 1] = bb.split(ei_i) + mi[(self.x_bitsize // self.mult_window_size) - j - 1] = bb.split(mi_i) + + a = np.concatenate(mi, axis=None) + a = bb.join(a, QUInt(self.x_bitsize)) + + b = bb.split(b) + mi = np.split(np.array(b), self.x_bitsize // self.mult_window_size) + for j in range(self.x_bitsize // self.mult_window_size): + data = list([(ke_inv * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke_inv in kes_inv) + ei_i = bb.join(ei[(self.exp_bitsize // self.exp_window_size) - i - 1], QUInt((self.exp_window_size))) + mi_i = bb.join(mi[(self.x_bitsize // self.mult_window_size) - j - 1], QUInt((self.mult_window_size))) + ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i) + t, a = bb.add(ModSub(QUInt(self.x_bitsize), self.mod), x=t, y=a) + junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))} + ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping) + ei[(self.exp_bitsize // self.exp_window_size) - i - 1] = bb.split(ei_i) + mi[(self.x_bitsize // self.mult_window_size) - j - 1] = bb.split(mi_i) + + b = np.concatenate(mi, axis=None) + + b = bb.join(b, QUInt(self.x_bitsize)) + + a, b = bb.add(Swap(self.x_bitsize), x=a, y=b) + + a = bb.split(a) + + x = bb.join(a, QUInt(self.x_bitsize)) + exponent = np.concatenate(ei, axis=None) + bb.free(b, dirty=True) + else: + # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method + base = self.base % self.mod + for j in range(self.exp_bitsize - 1, 0 - 1, -1): + exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) + base = (base * base) % self.mod return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x} def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - k = ssa.new_symbol('k') - return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1} + if self.exp_window_size is not None and self.mult_window_size is not None: + if is_symbolic(self.exp_window_size, self.mult_window_size): + num_iterations = self.exp_bitsize // self.exp_window_size + return {self.qrom(Shaped((2**(self.exp_window_size+self.mult_window_size),))): 1, + self.qrom(Shaped((2**(self.exp_window_size+self.mult_window_size),))).adjoint(): 1, + ModAdd(self.x_bitsize, self.mod): 1, + ModSub(QUInt(self.x_bitsize), self.mod): 1, + IntState(val=1, bitsize=self.x_bitsize): 1, Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size} + else: + cg = {IntState(val=1, bitsize=self.x_bitsize): 1, Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size} + + k = self.base + for i in range(self.exp_bitsize // self.exp_window_size): + kes = [pow(k, 2**i * x_e, self.mod) for x_e in range(2**self.exp_window_size)] + kes_inv = [pow(x_e, -1, self.mod) for x_e in kes] + + for j in range(self.x_bitsize // self.mult_window_size): + data = list([(ke * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke in kes) + cg[self.qrom(data)] = cg.get(self.qrom(data), 0) + 1 + cg[ModAdd(self.x_bitsize, self.mod)] = cg.get(ModAdd(self.x_bitsize, self.mod), 0) + 1 + cg[self.qrom(data).adjoint()] = cg.get(self.qrom(data).adjoint(), 0) + 1 + + for j in range(self.x_bitsize // self.mult_window_size): + data = list([(ke_inv * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke_inv in kes_inv) + cg[self.qrom(data)] = cg.get(self.qrom(data), 0) + 1 + cg[ModSub(QUInt(self.x_bitsize), self.mod)] = cg.get(ModSub(QUInt(self.x_bitsize), self.mod), 0) + 1 + cg[self.qrom(data).adjoint()] = cg.get(self.qrom(data).adjoint(), 0) + 1 + + return cg + else: + k = ssa.new_symbol('k') + return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1} def on_classical_vals(self, exponent) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: return {'exponent': exponent, 'x': (self.base**exponent) % self.mod} @@ -172,11 +286,23 @@ def _modexp() -> ModExp: return modexp +@bloq_example(generalizer=(ignore_split_join, _generalize_k)) +def _modexp_window() -> ModExp: + modexp_window = ModExp.make_for_shor(big_n=13 * 17, g=9, exp_window_size=8, mult_window_size=4) + return modexp_window + + @bloq_example def _modexp_symb() -> ModExp: g, N, n_e, n_x = sympy.symbols('g N n_e, n_x') modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x) return modexp_symb +@bloq_example +def _modexp_window_symb() -> ModExp: + g, N, n_e, n_x, w_e, w_m = sympy.symbols('g N n_e, n_x w_e w_m') + modexp_window_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x, exp_window_size=w_e, mult_window_size=w_m) + return modexp_window_symb + -_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb)) +_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb, _modexp_window, _modexp_window_symb)) diff --git a/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py index e4cabd8c4..c524b803a 100644 --- a/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py @@ -21,7 +21,7 @@ from qualtran import Bloq from qualtran.bloqs.bookkeeping import Join, Split -from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp +from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, _modexp_window, _modexp_window_symb, ModExp from qualtran.bloqs.mod_arithmetic import CModMulK from qualtran.drawing import Text from qualtran.resource_counting import SympySymbolAllocator @@ -48,6 +48,40 @@ def test_mod_exp_consistent_classical(): for i in range(len(ret1)): np.testing.assert_array_equal(ret1[i], ret2[i]) +@pytest.mark.parametrize('p', [11, 13]) +def test_mod_exp_window_consistent_classical_fast(p): + bloq = ModExp.make_for_shor(big_n=p, exp_window_size=2, mult_window_size=2) + + rs = np.random.RandomState(52) + n_x = int(np.ceil(np.log2(p))) + + for _ in range(10): + exponent = rs.randint(1, 2**n_x) + + ret1 = bloq.call_classically(exponent=exponent) + ret2 = bloq.decompose_bloq().call_classically(exponent=exponent) + assert len(ret1) == len(ret2) + for i in range(len(ret1)): + np.testing.assert_array_equal(ret1[i], ret2[i]) + +''' +@pytest.mark.slow +@pytest.mark.parametrize('p, w_e, w_m', [(p, w_e, w_m) for p in (7, 11, 13) for w_e in range(1, (2 * int(np.ceil(np.log2(p)))) + 1) if (2 * int(np.ceil(np.log2(p)))) % w_e == 0 for w_m in range(1, int(np.ceil(np.log2(p))) + 1) if int(np.ceil(np.log2(p))) % w_m == 0]) +def test_mod_exp_window_consistent_classical(p, w_e, w_m): + bloq = ModExp.make_for_shor(big_n=p, exp_window_size=w_e, mult_window_size=w_m) + + rs = np.random.RandomState(52) + n_x = int(np.ceil(np.log2(p))) + + for _ in range(10): + exponent = rs.randint(1, 2**n_x) + + ret1 = bloq.call_classically(exponent=exponent) + ret2 = bloq.decompose_bloq().call_classically(exponent=exponent) + assert len(ret1) == len(ret2) + for i in range(len(ret1)): + np.testing.assert_array_equal(ret1[i], ret2[i]) +''' def test_modexp_symb_manual(): g, N, n_e, n_x = sympy.symbols('g N n_e, n_x') @@ -89,7 +123,7 @@ def test_mod_exp_t_complexity(): assert tcomp.t > 0 -@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small]) +@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small, _modexp_window, _modexp_window_symb]) def test_modexp(bloq_autotester, bloq): bloq_autotester(bloq) diff --git a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py index 611af9de4..aeb4c68dc 100644 --- a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py +++ b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py @@ -51,15 +51,26 @@ class RSAPhaseEstimate(Bloq): n: The bitsize of the modulus N. mod: The modulus N; a part of the public key for RSA. base: A base for modular exponentiation. + exp_window_size: The window size of windowed arithmetic on the controlled modular + multiplications. + mult_window_size: The window size of windowed arithmetic on the modular product additions. References: + [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). + Gidney and EkerÄ. 2019. + [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). - Beauregard. 2003. Fig 1. + Stephane Beauregard. 2003. + + [Windowed quantum arithmetic](https://arxiv.org/abs/1905.07682). + Craig Gidney. 2019. """ n: 'SymbolicInt' mod: 'SymbolicInt' base: 'SymbolicInt' + exp_window_size: 'SymbolicInt' = 1 + mult_window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index 347c78b61..68f03ecd2 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -332,6 +332,7 @@ "qualtran.bloqs.data_loading.qrom.QROM": qualtran.bloqs.data_loading.qrom.QROM, "qualtran.bloqs.data_loading.qroam_clean.QROAMClean": qualtran.bloqs.data_loading.qroam_clean.QROAMClean, "qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint, + "qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjointWrapper": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjointWrapper, "qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM": qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM, "qualtran.bloqs.mod_arithmetic.CModAddK": qualtran.bloqs.mod_arithmetic.CModAddK, "qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd, @@ -340,7 +341,7 @@ "qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK, "qualtran.bloqs.mod_arithmetic.mod_addition.CtrlScaleModAdd": qualtran.bloqs.mod_arithmetic.CtrlScaleModAdd, "qualtran.bloqs.mod_arithmetic.ModAdd": qualtran.bloqs.mod_arithmetic.ModAdd, - "qualtran.bloqs.mod_arithmetic.ModSub": qualtran.bloqs.mod_arithmetic.ModSub, + "qualtran.bloqs.mod_arithmetic.mod_subtraction.ModSub": qualtran.bloqs.mod_arithmetic.mod_subtraction.ModSub, "qualtran.bloqs.mod_arithmetic.CModSub": qualtran.bloqs.mod_arithmetic.CModSub, "qualtran.bloqs.mod_arithmetic.mod_subtraction.ModNeg": qualtran.bloqs.mod_arithmetic.mod_subtraction.ModNeg, "qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg": qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg,