Skip to content

Commit 827e6d1

Browse files
committed
remove W array from s_mp_mul_comba and s_mp_sqr_comba
remove calls to comba from s_mp_mul and s_mp_mul_high

TODO:
* Remove remaining W arrays
* Replace mp_exch/mp_clear pairs with mp_clear/copy
* Check whether more mp_init* calls can be replaced by the MP_ALIAS/mp_init_size/mp_grow optimization
1 parent cc77fad commit 827e6d1

File tree

11 files changed

+122
-115
lines changed

11 files changed

+122
-115
lines changed

etc/tune.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ static int s_number_of_test_loops;
5858
static int s_stabilization_extra;
5959
static int s_offset = 1;
6060

61-
#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1)
61+
static mp_err s_mul_full(const mp_int *a, const mp_int *b, mp_int *c)
62+
{
63+
if (MP_HAS(S_MP_MUL_HIGH_COMBA)
64+
&& (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
65+
return s_mp_mul_comba(a, b, c, a->used + b->used + 1);
66+
}
67+
return s_mp_mul(a, b, c, a->used + b->used + 1);
68+
}
69+
6270
static uint64_t s_time_mul(int size)
6371
{
6472
int x;
@@ -87,7 +95,7 @@ static uint64_t s_time_mul(int size)
8795
goto LBL_ERR;
8896
}
8997
if (s_check_result == 1) {
90-
if ((e = s_mp_mul_full(&a,&b,&d)) != MP_OKAY) {
98+
if ((e = s_mul_full(&a,&b,&d)) != MP_OKAY) {
9199
t1 = UINT64_MAX;
92100
goto LBL_ERR;
93101
}

mp_mul.c

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,7 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
3131
} else if (MP_HAS(S_MP_MUL_KARATSUBA) &&
3232
(min >= MP_MUL_KARATSUBA_CUTOFF)) {
3333
err = s_mp_mul_karatsuba(a, b, c);
34-
} else if (MP_HAS(S_MP_MUL_COMBA) &&
35-
/* can we use the fast multiplier?
36-
*
37-
* The fast multiplier can be used if the output will
38-
* have less than MP_WARRAY digits and the number of
39-
* digits won't affect carry propagation
40-
*/
41-
(digs < MP_WARRAY) &&
34+
} else if (MP_HAS(S_MP_MUL_COMBA) && /* can we use the fast multiplier? */
4235
(min <= MP_MAX_COMBA)) {
4336
err = s_mp_mul_comba(a, b, c, digs);
4437
} else if (MP_HAS(S_MP_MUL)) {

mp_reduce.c

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,11 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
2323

2424
/* according to HAC this optimization is ok */
2525
if ((mp_digit)um > ((mp_digit)1 << (MP_DIGIT_BIT - 1))) {
26-
if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) {
27-
goto LBL_ERR;
28-
}
29-
} else if (MP_HAS(S_MP_MUL_HIGH)) {
30-
if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) {
31-
goto LBL_ERR;
32-
}
26+
if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) goto LBL_ERR;
3327
} else if (MP_HAS(S_MP_MUL_HIGH_COMBA)) {
34-
if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) {
35-
goto LBL_ERR;
36-
}
28+
if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR;
29+
} else if (MP_HAS(S_MP_MUL_HIGH)) {
30+
if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR;
3731
} else {
3832
err = MP_VAL;
3933
goto LBL_ERR;
@@ -43,41 +37,33 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
4337
mp_rshd(&q, um + 1);
4438

4539
/* x = x mod b**(k+1), quick (no division) */
46-
if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
47-
goto LBL_ERR;
48-
}
40+
if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) goto LBL_ERR;
4941

5042
/* q = q * m mod b**(k+1), quick (no division) */
51-
if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) {
52-
goto LBL_ERR;
43+
if (MP_HAS(S_MP_MUL_COMBA)
44+
&& (MP_MIN(q.used, m->used) < MP_MAX_COMBA)) {
45+
if ((err = s_mp_mul_comba(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR;
46+
} else {
47+
if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR;
5348
}
5449

5550
/* x = x - q */
56-
if ((err = mp_sub(x, &q, x)) != MP_OKAY) {
57-
goto LBL_ERR;
58-
}
51+
if ((err = mp_sub(x, &q, x)) != MP_OKAY) goto LBL_ERR;
5952

6053
/* If x < 0, add b**(k+1) to it */
6154
if (mp_cmp_d(x, 0uL) == MP_LT) {
6255
mp_set(&q, 1uL);
63-
if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) {
64-
goto LBL_ERR;
65-
}
66-
if ((err = mp_add(x, &q, x)) != MP_OKAY) {
67-
goto LBL_ERR;
68-
}
56+
if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) goto LBL_ERR;
57+
if ((err = mp_add(x, &q, x)) != MP_OKAY) goto LBL_ERR;
6958
}
7059

7160
/* Back off if it's too big */
7261
while (mp_cmp(x, m) != MP_LT) {
73-
if ((err = s_mp_sub(x, m, x)) != MP_OKAY) {
74-
goto LBL_ERR;
75-
}
62+
if ((err = s_mp_sub(x, m, x)) != MP_OKAY) goto LBL_ERR;
7663
}
7764

7865
LBL_ERR:
7966
mp_clear(&q);
80-
8167
return err;
8268
}
8369
#endif

mp_sqr.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ mp_err mp_sqr(const mp_int *a, mp_int *b)
1414
(a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
1515
err = s_mp_sqr_karatsuba(a, b);
1616
} else if (MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
17-
(((a->used * 2) + 1) < MP_WARRAY) &&
1817
(a->used < (MP_MAX_COMBA / 2))) {
1918
err = s_mp_sqr_comba(a, b);
2019
} else if (MP_HAS(S_MP_SQR)) {

s_mp_mul.c

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
*/
1010
mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
1111
{
12-
mp_int t;
12+
mp_int tmp, *c_;
1313
mp_err err;
1414
int pa, ix;
1515

16-
/* can we use the fast multiplier? */
17-
if ((digs < MP_WARRAY) &&
18-
(MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
19-
return s_mp_mul_comba(a, b, c, digs);
20-
}
21-
22-
if ((err = mp_init_size(&t, digs)) != MP_OKAY) {
16+
/* prepare the destination */
17+
err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
18+
? mp_init_size((c_ = &tmp), digs)
19+
: mp_grow((c_ = c), digs);
20+
if (err != MP_OKAY) {
2321
return err;
2422
}
25-
t.used = digs;
23+
24+
s_mp_zero_digs(c_->dp, c_->used);
25+
c_->used = digs;
2626

2727
/* compute the digits of the product directly */
2828
pa = a->used;
@@ -36,26 +36,29 @@ mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
3636
/* compute the columns of the output and propagate the carry */
3737
for (iy = 0; iy < pb; iy++) {
3838
/* compute the column as a mp_word */
39-
mp_word r = (mp_word)t.dp[ix + iy] +
39+
mp_word r = (mp_word)c_->dp[ix + iy] +
4040
((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
4141
(mp_word)u;
4242

4343
/* the new column is the lower part of the result */
44-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
44+
c_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
4545

4646
/* get the carry word from the result */
4747
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
4848
}
4949
/* set carry if it is placed below digs */
5050
if ((ix + iy) < digs) {
51-
t.dp[ix + pb] = u;
51+
c_->dp[ix + pb] = u;
5252
}
5353
}
5454

55-
mp_clamp(&t);
56-
mp_exch(&t, c);
55+
mp_clamp(c_);
56+
57+
if (c_ == &tmp) {
58+
mp_clear(c);
59+
*c = *c_;
60+
}
5761

58-
mp_clear(&t);
5962
return MP_OKAY;
6063
}
6164
#endif

s_mp_mul_comba.c

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
2323
{
2424
int oldused, pa, ix;
2525
mp_err err;
26-
mp_digit W[MP_WARRAY];
27-
mp_word _W;
26+
mp_word W;
27+
mp_int tmp, *c_;
2828

29-
/* grow the destination as required */
30-
if ((err = mp_grow(c, digs)) != MP_OKAY) {
29+
/* prepare the destination */
30+
err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
31+
? mp_init_size((c_ = &tmp), digs)
32+
: mp_grow((c_ = c), digs);
33+
if (err != MP_OKAY) {
3134
return err;
3235
}
3336

3437
/* number of output digits to produce */
3538
pa = MP_MIN(digs, a->used + b->used);
3639

3740
/* clear the carry */
38-
_W = 0;
41+
W = 0;
3942
for (ix = 0; ix < pa; ix++) {
4043
int tx, ty, iy, iz;
4144

@@ -50,29 +53,30 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
5053

5154
/* execute loop */
5255
for (iz = 0; iz < iy; ++iz) {
53-
_W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
56+
W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
5457
}
5558

5659
/* store term */
57-
W[ix] = (mp_digit)_W & MP_MASK;
60+
c_->dp[ix] = (mp_digit)W & MP_MASK;
5861

5962
/* make next carry */
60-
_W = _W >> (mp_word)MP_DIGIT_BIT;
63+
W = W >> (mp_word)MP_DIGIT_BIT;
6164
}
6265

6366
/* setup dest */
64-
oldused = c->used;
65-
c->used = pa;
66-
67-
for (ix = 0; ix < pa; ix++) {
68-
/* now extract the previous digit [below the carry] */
69-
c->dp[ix] = W[ix];
70-
}
67+
oldused = c_->used;
68+
c_->used = pa;
7169

7270
/* clear unused digits [that existed in the old copy of c] */
73-
s_mp_zero_digs(c->dp + c->used, oldused - c->used);
71+
s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used);
72+
73+
mp_clamp(c_);
74+
75+
if (c_ == &tmp) {
76+
mp_clear(c);
77+
*c = *c_;
78+
}
7479

75-
mp_clamp(c);
7680
return MP_OKAY;
7781
}
7882
#endif

s_mp_mul_high.c

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs)
1212
int pa, pb, ix;
1313
mp_err err;
1414

15-
/* can we use the fast multiplier? */
16-
if (MP_HAS(S_MP_MUL_HIGH_COMBA)
17-
&& ((a->used + b->used + 1) < MP_WARRAY)
18-
&& (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
19-
return s_mp_mul_high_comba(a, b, c, digs);
20-
}
21-
2215
if ((err = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) {
2316
return err;
2417
}

s_mp_sqr.c

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,36 @@
66
/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
77
mp_err s_mp_sqr(const mp_int *a, mp_int *b)
88
{
9-
mp_int t;
9+
mp_int tmp, *b_;
1010
int ix, pa;
1111
mp_err err;
1212

1313
pa = a->used;
14-
if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
14+
15+
/* prepare the destination */
16+
err = MP_ALIAS(a, b)
17+
? mp_init_size((b_ = &tmp), (2 * pa) + 1)
18+
: mp_grow((b_ = b), (2 * pa + 1));
19+
if (err != MP_OKAY) {
1520
return err;
1621
}
1722

23+
s_mp_zero_digs(b_->dp, b_->used);
24+
1825
/* default used is maximum possible size */
19-
t.used = (2 * pa) + 1;
26+
b_->used = (2 * pa) + 1;
2027

2128
for (ix = 0; ix < pa; ix++) {
2229
mp_digit u;
2330
int iy;
2431

2532
/* first calculate the digit at 2*ix */
2633
/* calculate double precision result */
27-
mp_word r = (mp_word)t.dp[2*ix] +
34+
mp_word r = (mp_word)b_->dp[2*ix] +
2835
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
2936

3037
/* store lower part in result */
31-
t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
38+
b_->dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
3239

3340
/* get the carry */
3441
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
@@ -40,26 +47,30 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
4047
/* now calculate the double precision result, note we use
4148
* addition instead of *2 since it's easier to optimize
4249
*/
43-
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
50+
r = (mp_word)b_->dp[ix + iy] + r + r + (mp_word)u;
4451

4552
/* store lower part */
46-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
53+
b_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
4754

4855
/* get carry */
4956
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
5057
}
5158
/* propagate upwards */
5259
while (u != 0uL) {
53-
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
54-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
60+
r = (mp_word)b_->dp[ix + iy] + (mp_word)u;
61+
b_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
5562
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
5663
++iy;
5764
}
5865
}
5966

60-
mp_clamp(&t);
61-
mp_exch(&t, b);
62-
mp_clear(&t);
67+
mp_clamp(b_);
68+
69+
if (b_ == &tmp) {
70+
mp_clear(b);
71+
*b = *b_;
72+
}
73+
6374
return MP_OKAY;
6475
}
6576
#endif

0 commit comments

Comments
 (0)