Skip to content

Generalize fast sum of powers for any length, not just power-of-two #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 109 additions & 36 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ impl Poly2 {
}

/// Raises `x` to the power `n` using binary exponentiation,
/// with (1 to 2)*lg(n) scalar multiplications.
/// TODO: a consttime version of this would be awfully similar to a Montgomery ladder.
/// with `(1 to 2)*lg(n)` scalar multiplications.
/// TODO: a consttime version of this would be similar to a Montgomery ladder.
pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
let mut result = Scalar::one();
let mut aux = *x; // x, x^2, x^4, x^8, ...
Expand All @@ -99,38 +99,104 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
result = result * aux;
}
n = n >> 1;
aux = aux * aux; // FIXME: one unnecessary mult at the last step here!
if n > 0 {
aux = aux * aux;
}
}
result
}

/// Takes the sum of all the powers of `x`, up to `n`
/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplications and additions.
/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
if !n.is_power_of_two() {
return sum_of_powers_slow(x, n);
}
if n == 0 || n == 1 {
return Scalar::from(n as u64);
}
let mut m = n;
let mut result = Scalar::one() + x;
let mut factor = *x;
while m > 2 {
factor = factor * factor;
result = result + factor * result;
m = m / 2;
/// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\)
/// using \\(O(\lg n)\\) multiplications. Length \\(n\\) is not considered secret
/// and algorithm is fastest when \\(n\\) is the power of two (\\(2\lg n + 1\\) multiplications).
///
/// ### Algorithm description
///
/// First, let \\(n\\) be a power of two.
/// Then, we can divide the polynomial in two halves like so:
/// \\[
/// \begin{aligned}
/// S(n) &= (1+\dots+x^{n-1}) \\\\
/// &= (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) \\\\
/// &= s + x^{n/2} s.
/// \end{aligned}
/// \\]
/// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\).
/// Recursively, the total sum can be defined as:
/// \\[
/// \begin{aligned}
/// S(0) &= 0 \\\\
/// S(n) &= s_{\lg n} \\\\
/// s_0 &= 1 \\\\
/// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1}
/// \end{aligned}
/// \\]
/// This representation allows us to do only \\(2 \cdot \lg n\\) multiplications:
/// squaring \\(x\\) and multiplying it by \\(s_{i-1}\\) at each iteration.
///
/// Lets apply this to \\(n\\) which is not a power of two. The intuition behind the generalized
/// algorithm is to combine all intermediate power-of-two-degree polynomials corresponding to the
/// bits of \\(n\\) that are equal to 1.
///
/// 1. Represent \\(n\\) in binary.
/// 2. For each bit which is set (from the lowest to the highest):
/// 1. Compute a corresponding power-of-two-degree polynomial using the above algorithm.
/// Since we can reuse all intermediate polynomials, this adds no overhead to computing
/// a polynomial for the highest bit.
/// 2. Multiply the polynomial by the next power of \\(x\\), relative to the degree of the
/// already computed result. This effectively _offsets_ the polynomial to a correct range of
/// powers, so it can be added directly with the rest.
/// The next power of \\(x\\) is computed along all the intermediate polynomials,
/// by multiplying it by power-of-two power of \\(x\\) computed in step 2.1.
/// 3. Add to the result.
///
/// (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using
/// bits \\(b_i\\) in \\(\\{0,1\\}\\):
/// \\[
/// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1}
/// \\]
/// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)),
/// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar
/// an intermediate polynomial with \\(2^i\\) terms using the above algorithm,
/// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\)
/// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\).
///
/// The full algorithm becomes:
/// \\[
/// \begin{aligned}
/// S(0) &= 0 \\\\
/// S(1) &= 1 \\\\
/// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\
/// &= S(i-1) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i
/// \end{aligned}
/// \\]
pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar {
let mut result = Scalar::zero();
let mut f = Scalar::one(); // next-power-of-x to offset subsequent polynomials based on preceding bits of n.
let mut s = Scalar::one(); // power-of-two polynomials: (1, 1+x, 1+x+x^2+x^3, 1+...+x^7, , 1+...+x^15, ...)
let mut p = *x; // power-of-two powers of x: (x, x^2, x^4, ..., x^{2^i})
while n > 0 {
// take a bit from n
let bit = n & 1;
n = n >> 1;

if bit == 1 {
// `n` is not secret, so it's okay to be vartime on bits of `n`.
result += f * s;
if n > 0 {
// avoid multiplication if no bits left
f = f * p;
}
}
if n > 0 {
// avoid multiplication if no bits left
s = s + p * s;
p = p * p;
}
}
result
}

// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).sum()
}

/// Given `data` with `len >= 32`, return the first 32 bytes.
pub fn read32(data: &[u8]) -> [u8; 32] {
let mut buf32 = [0u8; 32];
Expand Down Expand Up @@ -196,9 +262,14 @@ mod tests {
);
}

// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x)
}

#[test]
fn test_sum_of_powers() {
let x = Scalar::from(10u64);
fn test_sum_of_powers_pow2() {
let x = Scalar::from(1337133713371337u64);
assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0));
assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1));
assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2));
Expand All @@ -210,14 +281,16 @@ mod tests {
}

#[test]
fn test_sum_of_powers_slow() {
fn test_sum_of_powers_non_pow2() {
let x = Scalar::from(10u64);
assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero());
assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one());
assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from(11u64));
assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from(111u64));
assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from(1111u64));
assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from(11111u64));
assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from(111111u64));
assert_eq!(sum_of_powers(&x, 0), Scalar::zero());
assert_eq!(sum_of_powers(&x, 1), Scalar::one());
assert_eq!(sum_of_powers(&x, 2), Scalar::from(11u64));
assert_eq!(sum_of_powers(&x, 3), Scalar::from(111u64));
assert_eq!(sum_of_powers(&x, 4), Scalar::from(1111u64));
assert_eq!(sum_of_powers(&x, 5), Scalar::from(11111u64));
assert_eq!(sum_of_powers(&x, 6), Scalar::from(111111u64));
assert_eq!(sum_of_powers(&x, 7), Scalar::from(1111111u64));
assert_eq!(sum_of_powers(&x, 8), Scalar::from(11111111u64));
}
}