| #include "os.h" |
| #include <mp.h> |
| #include "dat.h" |
| |
| /* */ |
| /* from knuth's 1969 seminumerical algorithms, pp 233-235 and pp 258-260 */ |
| /* */ |
| /* mpvecdigmuladd is an assembly language routine that performs the inner */ |
| /* loop. */ |
| /* */ |
| /* the karatsuba trade off is set empirically by measuring the algs on */ |
| /* a 400 MHz Pentium II. */ |
| /* */ |
| |
| /* karatsuba like (see knuth pg 258) */ |
| /* prereq: p is already zeroed */ |
| static void |
| mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) |
| { |
| mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod; |
| int u0len, u1len, v0len, v1len, reslen; |
| int sign, n; |
| |
| /* divide each piece in half */ |
| n = alen/2; |
| if(alen&1) |
| n++; |
| u0len = n; |
| u1len = alen-n; |
| if(blen > n){ |
| v0len = n; |
| v1len = blen-n; |
| } else { |
| v0len = blen; |
| v1len = 0; |
| } |
| u0 = a; |
| u1 = a + u0len; |
| v0 = b; |
| v1 = b + v0len; |
| |
| /* room for the partial products */ |
| t = mallocz(Dbytes*5*(2*n+1), 1); |
| if(t == nil) |
| sysfatal("mpkaratsuba: %r"); |
| u0v0 = t; |
| u1v1 = t + (2*n+1); |
| diffprod = t + 2*(2*n+1); |
| res = t + 3*(2*n+1); |
| reslen = 4*n+1; |
| |
| /* t[0] = (u1-u0) */ |
| sign = 1; |
| if(mpveccmp(u1, u1len, u0, u0len) < 0){ |
| sign = -1; |
| mpvecsub(u0, u0len, u1, u1len, u0v0); |
| } else |
| mpvecsub(u1, u1len, u0, u1len, u0v0); |
| |
| /* t[1] = (v0-v1) */ |
| if(mpveccmp(v0, v0len, v1, v1len) < 0){ |
| sign *= -1; |
| mpvecsub(v1, v1len, v0, v1len, u1v1); |
| } else |
| mpvecsub(v0, v0len, v1, v1len, u1v1); |
| |
| /* t[4:5] = (u1-u0)*(v0-v1) */ |
| mpvecmul(u0v0, u0len, u1v1, v0len, diffprod); |
| |
| /* t[0:1] = u1*v1 */ |
| memset(t, 0, 2*(2*n+1)*Dbytes); |
| if(v1len > 0) |
| mpvecmul(u1, u1len, v1, v1len, u1v1); |
| |
| /* t[2:3] = u0v0 */ |
| mpvecmul(u0, u0len, v0, v0len, u0v0); |
| |
| /* res = u0*v0<<n + u0*v0 */ |
| mpvecadd(res, reslen, u0v0, u0len+v0len, res); |
| mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n); |
| |
| /* res += u1*v1<<n + u1*v1<<2*n */ |
| if(v1len > 0){ |
| mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n); |
| mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n); |
| } |
| |
| /* res += (u1-u0)*(v0-v1)<<n */ |
| if(sign < 0) |
| mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n); |
| else |
| mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n); |
| memmove(p, res, (alen+blen)*Dbytes); |
| |
| free(t); |
| } |
| |
| #define KARATSUBAMIN 32 |
| |
| void |
| mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) |
| { |
| int i; |
| mpdigit d; |
| mpdigit *t; |
| |
| /* both mpvecdigmuladd and karatsuba are fastest when a is the longer vector */ |
| if(alen < blen){ |
| i = alen; |
| alen = blen; |
| blen = i; |
| t = a; |
| a = b; |
| b = t; |
| } |
| if(blen == 0){ |
| memset(p, 0, Dbytes*(alen+blen)); |
| return; |
| } |
| |
| if(alen >= KARATSUBAMIN && blen > 1){ |
| /* O(n^1.585) */ |
| mpkaratsuba(a, alen, b, blen, p); |
| } else { |
| /* O(n^2) */ |
| for(i = 0; i < blen; i++){ |
| d = b[i]; |
| if(d != 0) |
| mpvecdigmuladd(a, alen, d, &p[i]); |
| } |
| } |
| } |
| |
| void |
| mpmul(mpint *b1, mpint *b2, mpint *prod) |
| { |
| mpint *oprod; |
| |
| oprod = nil; |
| if(prod == b1 || prod == b2){ |
| oprod = prod; |
| prod = mpnew(0); |
| } |
| |
| prod->top = 0; |
| mpbits(prod, (b1->top+b2->top+1)*Dbits); |
| mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); |
| prod->top = b1->top+b2->top+1; |
| prod->sign = b1->sign*b2->sign; |
| mpnorm(prod); |
| |
| if(oprod != nil){ |
| mpassign(prod, oprod); |
| mpfree(prod); |
| } |
| } |