18

I wanted to use NTT for fast squaring (see Fast bignum square computation), but the result is slow even for really big numbers (more than 12000 bits).

So my questions are:

  1. Is there a way to optimize my NTT transform? I did not mean to speed it by parallelism (threads); this is low-level layer only.
  2. Is there a way to speed up my modular arithmetics?

This is my (already optimized) C++ source code for the NTT. It is complete and 100% working in C++ without any need for third-party libs, and it should also be thread-safe. Beware: the source array is used as a temporary, and the transform cannot be done in place (dst must be a different array than src).

//---------------------------------------------------------------------------
// Number theoretic transform (NTT) over Z/pZ - first (reference) version.
//
// Usage: call NTT()/INTT() once with n>0 to set up the constants for that
// size; afterwards n may be omitted to reuse the last size.
//
// Notes:
//  - NTT_fast() halves n on every recursion level, so n is expected to be a
//    power of two.
//  - NTT_fast() uses src as scratch space (its contents are destroyed) and
//    dst must not be the same array as src.
//---------------------------------------------------------------------------
class fourier_NTT
{
public:
    DWORD r,L,p,N;          // root base, exponent step, prime modulus, transform size
    DWORD W,iW,rN;          // W=(r^L) mod p, iW=inverse of W, rN=inverse of N (mod p)

    // N is cleared too, so NTT()/INTT() on an object that was never initialized
    // is a harmless no-op instead of reading an indeterminate size.
    fourier_NTT() { r=0; L=0; p=0; N=0; W=0; iW=0; rN=0; }

    // main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);             // DWORD dst[n] = fast  NTT(DWORD src[n])
    void INTT(DWORD *dst,DWORD *src,DWORD n=0);             // DWORD dst[n] = fast INTT(DWORD src[n])
    // helper functions
    bool init(DWORD n);                                     // init r,L,p,W,iW,rN for size n
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);   // DWORD dst[n] = fast  NTT(DWORD src[n])
    // only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);  // DWORD dst[n] = slow  NTT(DWORD src[n])
    void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);  // DWORD dst[n] = slow INTT(DWORD src[n])
    // DWORD arithmetics
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);
    // modular arithmetics
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
};
//---------------------------------------------------------------------------
// Forward transform. If n>0 the constants are (re)computed for that size,
// otherwise the size of the last init() is reused.
void fourier_NTT::NTT(DWORD *dst,DWORD *src,DWORD n)
{
    if (n>0) init(n);           // a rejected n leaves N=0, so nothing is transformed
    NTT_fast(dst,src,N,W);
//  NTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
// Inverse transform: same butterfly network with iW, then scaling by 1/N.
void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
{
    if (n>0) init(n);
    NTT_fast(dst,src,N,iW);
    for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
//  INTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
// Computes the constants for transform size n.
// Returns false (and clears all constants, N included) when n is out of range.
bool fourier_NTT::init(DWORD n)
{
    // (max(src[])^2)*n < p else NTT overflow can occur !!!
    r=2; p=0xC0000001;                                      // 32:30 bit best for unsigned 32 bit
    if ((n<2)||(n>0x10000000))
    {
        r=0; L=0; p=0; W=0; iW=0; rN=0; N=0;
        return false;
    }
    L=0x30000000/n;
    // Alternative parameter sets kept from the original source:
//  r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//  r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//  r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
    N=n;                        // size of vectors [DWORDs]
    W =modpow(r,    L);         // Wn for NTT
    iW=modpow(r,p-1-L);         // Wn for INTT
    rN=modpow(n,p-2  );         // scale for INTT
    return true;
}
//---------------------------------------------------------------------------
// Recursive radix-2 transform of src[n] into dst[n] with root w.
// src is overwritten (it serves as the temporary of the recursion).
void fourier_NTT::NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
    // reorder: even samples to the lower half, odd samples to the upper half
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // recursion (dst and src swap roles)
    NTT_fast(src   ,dst   ,n2,w2);  // even
    NTT_fast(src+n2,dst+n2,n2,w2);  // odd
    // restore results (butterflies), w2 now runs through w^0,w^1,...
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
    {
        a0=src[i];
        a1=modmul(src[j],w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
    }
}
//---------------------------------------------------------------------------
// O(n^2) transform straight from the definition; kept for testing.
void fourier_NTT::NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    DWORD i,j,wj,wi,a;
    for (wj=1,j=0;j<n;j++)
    {
        a=0;
        for (wi=1,i=0;i<n;i++)
        {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
        }
        dst[j]=a;
        wj=modmul(wj,w);
    }
}
//---------------------------------------------------------------------------
// O(n^2) inverse transform; uses the members iW and rN (parameter w is not used).
void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    DWORD i,j,wi,wj,a;
    for (wj=1,j=0;j<n;j++)
    {
        a=0;
        for (wi=1,i=0;i<n;i++)
        {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
        }
        dst[j]=modmul(a,rN);
        wj=modmul(wj,iW);
    }
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::shl(DWORD a) { return (a<<1)&0xFFFFFFFE; }
DWORD fourier_NTT::shr(DWORD a) { return (a>>1)&0x7FFFFFFF; }
//---------------------------------------------------------------------------
// a mod p by shift-and-subtract (works for any p, no division needed).
DWORD fourier_NTT::mod(DWORD a)
{
    DWORD bb;
    // scale p up until it exceeds a or its top bit is set
    for (bb=p;(a>bb)&&(!(bb&0x80000000));bb=shl(bb));
    // subtract the scaled moduli from the biggest down to p itself
    for (;;)
    {
        if (a>=bb) a-=bb;
        if (bb==p) break;
        bb=shr(bb);
    }
    return a;
}
//---------------------------------------------------------------------------
// (a+b) mod p; the 33rd bit of the sum is recovered in cy.
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
{
    DWORD d,cy;
    a=mod(a);
    b=mod(b);
    d=a+b;
    cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000;    // carry out of bit 31
    if (cy) d-=p;
    if (d>=p) d-=p;
    return d;
}
//---------------------------------------------------------------------------
// (a-b) mod p
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
{
    DWORD d;
    a=mod(a);
    b=mod(b);
    d=a-b; if (a<b) d+=p;
    if (d>=p) d-=p;
    return d;
}
//---------------------------------------------------------------------------
// (a*b) mod p by binary shift-and-add; b does not need to be reduced.
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
{
    int i;
    DWORD d;
    a=mod(a);
    for (d=0,i=0;i<32;i++)
    {
        if (a&1) d=modadd(d,b);
        a=shr(a);
        b=modadd(b,b);
    }
    return d;
}
//---------------------------------------------------------------------------
// (a^b) mod p by square-and-multiply from the top bit; a,b need not be reduced.
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
{
    int i;
    DWORD d=1;
    for (i=0;i<32;i++)
    {
        d=modmul(d,d);
        if (b&0x80000000) d=modmul(d,a);
        b=shl(b);
    }
    return d;
}
//---------------------------------------------------------------------------

Example of usage of my NTT class:

fourier_NTT ntt; const DWORD n=32 DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N]; ntt.NTT(z,x,N); // z[N]=NTT(x[N]), also init constants for N ntt.NTT(x,y); // x[N]=NTT(y[N]), no recompute of constants, use last N // modular convolution y[]=z[].x[] for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]); ntt.INTT(x,y); // x[N]=INTT(y[N]), no recompute of constants, use last N // x[]=convolution of original x[].y[] 

Some measurements before optimizations (non Class NTT):

a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul 

Some measurements after my optimizations (current code, lower recursion parameter size/count, and better modular arithmetics):

a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.214 ms ] fast sqr sqr2[ 208.298 ms ] NTT sqr mul1[ 5.564 ms ] simpe mul mul2[ 3.113 ms ] karatsuba mul mul3[ 302.740 ms ] NTT mul 

Check the NTT mul and NTT sqr times (my optimizations sped them up a little over 3x). The loop ran only once, so the measurement is not very precise (error ~10%), but the speedup is noticeable even now (normally I loop 1000x or more, but my NTT is too slow for that).

You can use my code freely... Just keep my nick and/or link to this page somewhere (rem in code, readme.txt, about or whatever). I hope it helps... (I did not see C++ source for fast NTTs anywhere so I had to write it by myself). Roots of unity were tested for all accepted N, see the fourier_NTT::init(DWORD n) function.

P.S.: For more information about NTT, see Translation from Complex-FFT to Finite-Field-FFT. This code is based on my posts inside that link.

[edit1:] Further changes in the code

I managed to further optimize my modular arithmetics by exploiting the fact that the modulus prime is always 0xC0000001 and by eliminating unnecessary calls. The resulting speedup is stunning (more than 40x), and NTT multiplication is now faster than Karatsuba above a threshold of about 1500 * 32 bits. BTW, the speed of my NTT is now the same as my optimized DFFT on 64-bit doubles.

Some measurements:

a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] karatsuba mul mul3[ 26.311 ms ] NTT mul 

New source code for modular arithmetics:

//--------------------------------------------------------------------------- DWORD fourier_NTT::mod(DWORD a) { if (a>p) a-=p; return a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modadd(DWORD a,DWORD b) { DWORD d,cy; if (a>p) a-=p; if (b>p) b-=p; d=a+b; cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000; if (cy ) d-=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modsub(DWORD a,DWORD b) { DWORD d; if (a>p) a-=p; if (b>p) b-=p; d=a-b; if (a<b) d+=p; if (d>p) d-=p; return d; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modmul(DWORD a,DWORD b) { DWORD _a,_b,_p; _a=a; _b=b; _p=p; asm { mov eax,_a mov ebx,_b mul ebx // H(edx),L(eax) = eax * ebx mov ebx,_p div ebx // eax = H(edx),L(eax) / ebx mov _a,edx // edx = H(edx),L(eax) % ebx } return _a; } //--------------------------------------------------------------------------- DWORD fourier_NTT::modpow(DWORD a,DWORD b) { // b bez orezania! int i; DWORD d=1; if (a>p) a-=p; for (i=0;i<32;i++) { d=modmul(d,d); if (DWORD(b&0x80000000)) d=modmul(d,a); b<<=1; } return d; } //--------------------------------------------------------------------------- 

As you can see, the functions shl and shr are no longer used. I think that modpow can be optimized further, but it is not a critical function because it is called only a few times. The most critical function is modmul, and that seems to be in the best shape possible.

Further questions:

  • Is there any other option to speedup NTT?
  • Are my optimizations of modular arithmetics safe? (Results seem to be the same, but I could miss something.)

[edit2] New optimizations

a = 0.99991970486 | 2000*32 bits looped 10x sqr1[ 13.908 ms ] fast sqr sqr2[ 13.649 ms ] NTT sqr mul1[ 19.726 ms ] simpe mul mul2[ 31.808 ms ] karatsuba mul mul3[ 19.373 ms ] NTT mul 

I implemented all the usable stuff from all of your comments (thanks for the insight).

Speedups:

  • +2.5% by removing unnecessary safety mods (Mandalf The Beige)
  • +34.9% by use of precomputed W,iW powers (Mysticial)
  • +35% total

Actual full source code:

//---------------------------------------------------------------------------
//--- Number theoretic transforms: 2.03 -------------------------------------
//---------------------------------------------------------------------------
#ifndef _fourier_NTT_h
#define _fourier_NTT_h
//---------------------------------------------------------------------------
// Number theoretic transform (NTT) over Z/pZ with precomputed twiddle tables.
//
// Usage: call NTT()/iNTT() once with n>0 to set up constants and tables for
// that size; afterwards n may be omitted to reuse the last size.
//
// Notes:
//  - NTT_fast() halves n on every recursion level, so n is expected to be a
//    power of two.
//  - NTT_fast() uses src as scratch space (its contents are destroyed) and
//    dst must not be the same array as src.
//  - All functions are defined inline so the header can be included from
//    several translation units.
//---------------------------------------------------------------------------
class fourier_NTT
{
public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;          // W=(r^L) mod p, iW=inverse W, rN = inverse N
    DWORD *WW,*iWW,NN;      // Precomputed (W,iW)^(0,..,NN-1) powers

    // Internals
    // N is cleared too, so a transform on a never initialized object is a no-op.
    fourier_NTT() { r=0; L=0; p=0; N=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
    // The object owns WW[] and iWW[], so copies must be deep (a member-wise
    // copy would delete the same tables twice).
    fourier_NTT(const fourier_NTT &src) { WW=NULL; iWW=NULL; NN=0; _copy(src); }
    fourier_NTT& operator = (const fourier_NTT &src) { if (this!=&src) _copy(src); return *this; }
    ~fourier_NTT() { _free(); }
    void _free();                       // Free precomputed W,iW powers tables
    void _alloc(DWORD n);               // Allocate and precompute W,iW powers tables
    void _grow(DWORD *&tab,DWORD n,DWORD w);   // Resize one powers table to n entries
    void _copy(const fourier_NTT &src); // Deep copy of constants and tables
    // Main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);             // DWORD dst[n] = fast  NTT(DWORD src[n])
    void iNTT(DWORD *dst,DWORD *src,DWORD n=0);             // DWORD dst[n] = fast INTT(DWORD src[n])
    // Helper functions
    bool init(DWORD n);                                     // init r,L,p,W,iW,rN
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);   // DWORD dst[n] = fast  NTT(DWORD src[n])
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);
    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);  // DWORD dst[n] = slow  NTT(DWORD src[n])
    void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);  // DWORD dst[n] = slow INTT(DWORD src[n])
    // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
};
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
inline void fourier_NTT::_free()
{
    NN=0;
    if ( WW) delete[]  WW;  WW=NULL;
    if (iWW) delete[] iWW; iWW=NULL;
}
//---------------------------------------------------------------------------
// Resizes tab[] from the current NN entries to n entries with tab[i]=w^i mod p.
// Entries that already exist are kept, only the new ones are computed.
inline void fourier_NTT::_grow(DWORD *&tab,DWORD n,DWORD w)
{
    DWORD i,a;
    DWORD have=tab?NN:0;                // number of valid entries in the old table
    DWORD *tmp=new DWORD[n];
    for (i=0;i<have;i++) tmp[i]=tab[i];
    if (tab) delete[] tab;
    tab=tmp;
    tab[0]=1;
    for (i=have?have:1,a=tab[i-1];i<n;i++) { a=modmul(a,w); tab[i]=a; }
}
//---------------------------------------------------------------------------
// Makes sure the W and iW powers tables hold at least n entries.
// The tables are extended with the CURRENT W,iW, so they must be released
// with _free() whenever W changes (init() takes care of that).
inline void fourier_NTT::_alloc(DWORD n)
{
    if (n<=NN) return;
    _grow( WW,n, W);
    _grow(iWW,n,iW);
    NN=n;
}
//---------------------------------------------------------------------------
inline void fourier_NTT::_copy(const fourier_NTT &src)
{
    DWORD i,cnt=((src.WW)&&(src.iWW))?src.NN:0;
    DWORD *w=NULL,*iw=NULL;
    if (cnt)
    {
        w=new DWORD[cnt];
        try { iw=new DWORD[cnt]; } catch (...) { delete[] w; throw; }
        for (i=0;i<cnt;i++) { w[i]=src.WW[i]; iw[i]=src.iWW[i]; }
    }
    _free();
    r=src.r; L=src.L; p=src.p; N=src.N;
    W=src.W; iW=src.iW; rN=src.rN;
    WW=w; iWW=iw; NN=cnt;
}
//---------------------------------------------------------------------------
// Forward transform. If n>0 constants and tables are set up for that size,
// otherwise the size of the last init() is reused.
inline void fourier_NTT::NTT(DWORD *dst,DWORD *src,DWORD n)
{
    if (n>0) init(n);           // a rejected n leaves N=0, so nothing is transformed
    NTT_fast(dst,src,N,WW,1);
//  NTT_fast(dst,src,N,W);
//  NTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
// Inverse transform: same butterfly network with the iW table, then scaling by 1/N.
inline void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
{
    if (n>0) init(n);
    NTT_fast(dst,src,N,iWW,1);
//  NTT_fast(dst,src,N,iW);
    for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
//  iNTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
// Computes constants and twiddle tables for transform size n.
// Returns false (and clears all constants, N included) when n is out of range.
inline bool fourier_NTT::init(DWORD n)
{
    // (max(src[])^2)*n < p else NTT overflow can occur!!!
    r=2; p=0xC0000001;                                      // 32:30 bit best for unsigned 32 bit
    if ((n<2)||(n>0x10000000))
    {
        r=0; L=0; p=0; W=0; iW=0; rN=0; N=0;
        return false;
    }
    L=0x30000000/n;
    // Alternative parameter sets kept from the original source:
//  r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//  r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//  r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
    // W depends on n, so tables built for another size hold powers of the
    // wrong root; drop them instead of reusing or extending them.
    if (n!=N) _free();
    N=n;                        // Size of vectors [DWORDs]
    W =modpow(r,    L);         // Wn for NTT
    iW=modpow(r,p-1-L);         // Wn for INTT
    rN=modpow(n,p-2  );         // Scale for INTT
    _alloc(n>>1);               // Precompute W,iW powers
    return true;
}
//---------------------------------------------------------------------------
// Recursive radix-2 transform with the root passed by value (older variant,
// kept for testing). src is overwritten.
inline void fourier_NTT::NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
    // Reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // Recursion (dst and src swap roles)
    NTT_fast(src   ,dst   ,n2,w2);  // Even
    NTT_fast(src+n2,dst+n2,n2,w2);  // Odd
    // Restore results, w2 now runs through w^0,w^1,...
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
    {
        a0=src[i];
        a1=modmul(src[j],w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
    }
}
//---------------------------------------------------------------------------
// Recursive radix-2 transform with table lookup: w2[] holds the powers of the
// top level root and i2 is the stride for the current level (1 at the top,
// doubled on each recursion level). src is overwritten.
inline void fourier_NTT::NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
{
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1;
    // Reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // Recursion (dst and src swap roles)
    i=i2<<1;
    NTT_fast(src   ,dst   ,n2,w2,i);    // Even
    NTT_fast(src+n2,dst+n2,n2,w2,i);    // Odd
    // Restore results
    for (i=0,j=n2;i<n2;i++,j++,w2+=i2)
    {
        a0=src[i];
        a1=modmul(src[j],*w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
    }
}
//---------------------------------------------------------------------------
// O(n^2) transform straight from the definition; kept for testing.
inline void fourier_NTT::NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    DWORD i,j,wj,wi,a;
    for (wj=1,j=0;j<n;j++)
    {
        a=0;
        for (wi=1,i=0;i<n;i++)
        {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
        }
        dst[j]=a;
        wj=modmul(wj,w);
    }
}
//---------------------------------------------------------------------------
// O(n^2) inverse transform; uses the members iW and rN (parameter w is not used).
inline void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
    DWORD i,j,wi,wj,a;
    for (wj=1,j=0;j<n;j++)
    {
        a=0;
        for (wi=1,i=0;i<n;i++)
        {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
        }
        dst[j]=modmul(a,rN);
        wj=modmul(wj,iW);
    }
}
//---------------------------------------------------------------------------
// a mod p (one subtraction is enough because p >= 0x80000000).
// ">=" on purpose: with ">" the value p itself was left unreduced.
inline DWORD fourier_NTT::mod(DWORD a)
{
    if (a>=p) a-=p;
    return a;
}
//---------------------------------------------------------------------------
// (a+b) mod p, operands are expected to be already reduced.
inline DWORD fourier_NTT::modadd(DWORD a,DWORD b)
{
    DWORD d,cy;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;  // carry out of bit 31 of a+b
    if (cy) d-=p;
    if (d>=p) d-=p;
    return d;
}
//---------------------------------------------------------------------------
// (a-b) mod p, operands are expected to be already reduced.
inline DWORD fourier_NTT::modsub(DWORD a,DWORD b)
{
    DWORD d;
    d=a-b; if (a<b) d+=p;
    if (d>=p) d-=p;
    return d;
}
//---------------------------------------------------------------------------
// (a*b) mod p through the 64 bit product.
// The x86 DIV faults when the quotient does not fit into 32 bits, so at least
// one of the operands should be already reduced (< p).
inline DWORD fourier_NTT::modmul(DWORD a,DWORD b)
{
#if defined(__BORLANDC__) && !defined(__clang__)
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm {
        mov eax,_a
        mov ebx,_b
        mul ebx     // H(edx),L(eax) = eax * ebx
        mov ebx,_p
        div ebx     // eax = H(edx),L(eax) / ebx
        mov _a,edx  // edx = H(edx),L(eax) % ebx
        }
    return _a;
#else
    // portable fallback for compilers without Borland style inline assembly
    return static_cast<DWORD>((static_cast<unsigned long long>(a)*b)%p);
#endif
}
//---------------------------------------------------------------------------
// (a^b) mod p, square-and-multiply from the top bit; b is not mod(p)!
inline DWORD fourier_NTT::modpow(DWORD a,DWORD b)
{
    int i;
    DWORD d=1;
    for (i=0;i<32;i++)
    {
        d=modmul(d,d);
        if (b&0x80000000) d=modmul(d,a);
        b<<=1;
    }
    return d;
}
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
#endif
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------

There is still the possibility of reducing stack traffic by splitting NTT_fast into two functions, one using WW[] and the other using iWW[], which would need one parameter less in the recursive calls. But I do not expect much from it (it saves only one 32-bit pointer), and I would rather have one function for better code management in the future. Many functions are dormant now (kept for testing), like the slow variants, mod, and the older fast function (with the w parameter instead of *w2,i2).

To avoid overflows for big datasets, limit the input values to a quarter of the element width: with b bits per NTT element, use at most b/4 bits per input value. So for this 32-bit version, use at most 8-bit input values (32 bit / 4 -> 8 bit).

[edit3] Simple string bigint multiplication for testing

//--------------------------------------------------------------------------- char* mul_NTT(const char *sx,const char *sy) { char *s; int i,j,k,n; // n = min power of 2 <= 2 max length(x,y) for (i=0;sx[i];i++); for (n=1;n<i;n<<=1); i--; for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--; DWORD *x,*y,*xx,*yy,a; x=new DWORD[n]; xx=new DWORD[n]; y=new DWORD[n]; yy=new DWORD[n]; // Zero padding for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0; for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0; //NTT fourier_NTT ntt; ntt.NTT(xx,x,n); ntt.NTT(yy,y); // Convolution for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]); //INTT ntt.iNTT(yy,xx); //suma a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0; delete[] x; delete[] xx; delete[] y; delete[] yy; return s; } //--------------------------------------------------------------------------- 

I use AnsiStrings, so I ported this to char*; hopefully I did not make a mistake. It looks like it works properly (in comparison to the AnsiString version).

  • sx,sy are decimal integer numbers
  • Returns allocated string (char*)=sx*sy

This is only ~4 bit per 32 bit data word so there is no risk of overflow, but it is slower of course. In my bignum lib I use a binary representation and use 8 bit chunks per 32-bit WORD for NTT. More than that is risky if N is big ...

Have fun with this

19
  • 3
    You probably realized that the multiply-mod divisions will kill performance. Take a look at some of the fast multiply-mod algorithms: Montgomery Reduction, Barrett's Reduction. The second method is for large numbers, but the method works just as well on word-size integers. There's a modern high-performance implementation using this that they're considering putting into GMP. Commented Jun 13, 2014 at 17:48
  • @Mysticial I do not see any benefit for modmul (I do not use big numbers just DWORDs this is not used for cryptography) but there is possible speedup for modpow on DWORDs. I do not see a straightforward algorithm on the Montgomery Reduction just few sentences about it. Do you have something working to compare actual speeds? BTW +1 for nice talk about NTT why did I not find it when I was coding this :( would save me a lot of time. Commented Jun 13, 2014 at 18:09
  • I don't (yet), but I'm considering implementing a 64-bit small prime NTT for y-cruncher. The slide that I linked to has benchmarks for the various butterflies. The fastest ones run in about 10 cycles per butterfly. (10 cycles by itself is already much cheaper than an integer division.) It was done by GMP developers so they know what they are doing. Commented Jun 13, 2014 at 18:14
  • 1
    One last thing: You can eliminate half your multiply-mod calls if you pre-compute the twiddle factors. Commented Jun 13, 2014 at 19:00
  • 1
    If you're interested, I just finished a prototype implementation of this NTT algorithm. You can take a look here. I didn't try very hard to optimize it, but it uses multiple primes and can do a billion digit multiply (3.33 * 10^9 x 3.33 * 10^9 -> 6.64 *10^9 bits) in about 37 seconds on my 4 GHz Haswell. If you compile and run it, it will show a very simple UI that lets you run various test sizes. Commented Nov 19, 2014 at 9:30

1 Answer 1

5

First off, thank you very much for posting and making it free to use. I really appreciate that.

I was able to use some bit tricks to eliminate some branching, rearranged the main loop, and modified the assembly, and was able to get a 1.35x speedup.

Also, I added a preprocessor condition for 64 bit, seeing as Visual Studio doesn't allow inline assembly in 64 bit mode (thank you Microsoft; feel free to go screw yourself).

Something strange happened when I was optimizing the modsub() function. I rewrote it using bit hacks like I did modadd (which was faster). But for some reason, the bit wise version of modsub was slower. Not sure why. Might just be my computer.

//
// Mandalf The Beige
// Based on:
// Spektre
// http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
//
// This code may be freely used however you choose, so long as it is accompanied by this notice.
//
#ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

#include <stdint.h>
#include <string.h>

// Fixed width types: "unsigned long" is 64 bits wide on LP64 platforms, which
// breaks the wrap-around tricks used by modadd()/modsub().
#ifndef uint32
#define uint32 uint32_t
#endif
#ifndef uint64
#define uint64 uint64_t
#endif

// Number theoretic transform over Z/pZ with p = 0xC0000001.
//
// Call NTT()/INTT() once with n > 0 to set up the constants for that size;
// afterwards n may be omitted to reuse the last size. The transform halves n
// on every recursion level, so n is expected to be a power of two.
// Unless dst == src, the src array is used as scratch space and destroyed.
// All functions are inline so the header can be used from several
// translation units.
class fast_ntt
{
public:
    // N is cleared too, so a transform on a never initialized object is a no-op.
    fast_ntt()
    {
        r = 0; L = 0; N = 0;
        W = 0; iW = 0; rN = 0;
    }
    // main interface
    void NTT(uint32 *dst, uint32 *src, uint32 n = 0);               // uint32 dst[n] = fast  NTT(uint32 src[n])
    void INTT(uint32 *dst, uint32 *src, uint32 n = 0);              // uint32 dst[n] = fast INTT(uint32 src[n])

private:
    // helper functions
    bool init(uint32 n);                                            // init r,L,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);    // recursive worker, destroys src
    void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);    // allows dst == src
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w); // keeps src intact
    // only for testing
    void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);    // uint32 dst[n] = slow  NTT(uint32 src[n])
    void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);   // uint32 dst[n] = slow INTT(uint32 src[n])
    // modular arithmetics (operands are expected to be already reduced)
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N;
    uint32 W, iW, rN;               // W=(r^L) mod p, iW=inverse of W, rN=inverse of N
    const uint32 p = 0xC0000001;
};

//---------------------------------------------------------------------------
inline void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);                    // a rejected n leaves N = 0, so nothing is transformed
    }
    NTT_fast(dst, src, N, W);
    // NTT_slow(dst,src,N,W);
}

//---------------------------------------------------------------------------
inline void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);
    }
    NTT_fast(dst, src, N, iW);
    for (uint32 i = 0; i < N; i++)
    {
        dst[i] = modmul(dst[i], rN);
    }
    // INTT_slow(dst,src,N,W);
}

//---------------------------------------------------------------------------
// Computes the constants for transform size n.
// Returns false (and clears the constants, N included) when n is out of range.
inline bool fast_ntt::init(uint32 n)
{
    // (max(src[])^2)*n < p else NTT overflow can occur !!!
    r = 2;                          // p = 0xC0000001 is a class constant
    if ((n < 2) || (n > 0x10000000))
    {
        r = 0; L = 0; W = 0;
        iW = 0; rN = 0; N = 0;
        return false;
    }
    L = 0x30000000 / n;             // 32:30 bit best for unsigned 32 bit
    // Alternative parameter sets of the original code (they need a non-const p):
    // r=2; p=0x78000001; n<=0x04000000; L=0x3c000000/n; // 31:27 bit best for signed 32 bit
    // r=2; p=0x00010001; n<=0x00000020; L=0x00000020/n; // 17:16 bit best for 16 bit
    // r=2; p=0x0a000001; n<=0x01000000; L=0x01000000/n; // 28:25 bit
    N = n;                          // size of vectors [uint32s]
    W = modpow(r, L);               // Wn for NTT
    iW = modpow(r, p - 1 - L);      // Wn for INTT
    rN = modpow(n, p - 2);          // scale for INTT
    return true;
}

//---------------------------------------------------------------------------
// Transform with writable source: src is used as scratch space unless
// dst == src, in which case a temporary output buffer is used.
inline void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        if (dst != src)
        {
            NTT_calc(dst, src, n, w);
        }
        else
        {
            uint32 *temp = new uint32[n];
            NTT_calc(temp, src, n, w);
            memcpy(dst, temp, n * sizeof(uint32));
            delete[] temp;
        }
    }
    else if (n == 1)
    {
        dst[0] = src[0];
    }
}

// Transform with read-only source: works on a private copy of src.
inline void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        uint32 *temp = new uint32[n];
        memcpy(temp, src, n * sizeof(uint32));
        NTT_calc(dst, temp, n, w);
        delete[] temp;
    }
    else if (n == 1)
    {
        dst[0] = src[0];
    }
}

// Recursive radix-2 worker; dst and src swap roles on every level.
inline void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n <= 1)
    {
        return;
    }
    uint32 i, j, a0, a1, n2 = n >> 1, w2 = modmul(w, w);
    // reorder even,odd
    for (i = 0, j = 0; i < n2; i++, j += 2)
    {
        dst[i] = src[j];
    }
    for (j = 1; i < n; i++, j += 2)
    {
        dst[i] = src[j];
    }
    // recursion
    if (n2 > 1)
    {
        NTT_calc(src, dst, n2, w2);             // even
        NTT_calc(src + n2, dst + n2, n2, w2);   // odd
    }
    else
    {
        // n2 == 1: one element transforms are plain copies
        src[0] = dst[0];
        src[1] = dst[1];
    }
    // restore results; the first butterfly has twiddle 1, so its multiplication is skipped
    w2 = 1;
    i = 0;
    j = n2;
    a0 = src[i];
    a1 = src[j];
    dst[i] = modadd(a0, a1);
    dst[j] = modsub(a0, a1);
    while (++i < n2)
    {
        w2 = modmul(w2, w);
        j++;
        a0 = src[i];
        a1 = modmul(src[j], w2);
        dst[i] = modadd(a0, a1);
        dst[j] = modsub(a0, a1);
    }
}

//---------------------------------------------------------------------------
// O(n^2) transform straight from the definition; kept for testing.
inline void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wj, wi, a;
    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = a;
        wj = modmul(wj, w);
    }
}

//---------------------------------------------------------------------------
// O(n^2) inverse transform; uses the members iW and rN (parameter w is not used).
inline void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wi, wj, a;
    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = modmul(a, rN);
        wj = modmul(wj, iW);
    }
}

//---------------------------------------------------------------------------
// (a+b) mod p for a,b < p
inline uint32 fast_ntt::modadd(uint32 a, uint32 b)
{
    uint32 d = a + b;
    if (d < a)                      // the 32 bit sum wrapped around
    {
        d -= p;
    }
    if (d >= p)
    {
        d -= p;
    }
    return d;
}

//---------------------------------------------------------------------------
// (a-b) mod p for a,b < p
inline uint32 fast_ntt::modsub(uint32 a, uint32 b)
{
    uint32 d = a - b;
    if (d > a)                      // borrow: a < b
    {
        d += p;
    }
    return d;
}

//---------------------------------------------------------------------------
// (a*b) mod p through the 64 bit product.
inline uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
#if defined(_MSC_VER) && defined(_M_IX86)
    // 32 bit MSVC: inline assembly. DIV faults when the quotient does not fit
    // into 32 bits, so at least one operand should be already reduced (< p).
    uint32 _a = a;
    uint32 _b = b;
    uint32 _p = p;
    __asm
    {
        mov eax, _a
        mul _b
        div _p
        mov _a, edx
    }
    return _a;
#else
    // 64 bit MSVC has no inline assembler; other compilers use another asm syntax
    return static_cast<uint32>((static_cast<uint64>(a) * b) % p);
#endif
}

// (a^b) mod p, square-and-multiply from the lowest bit of b.
inline uint32 fast_ntt::modpow(uint32 a, uint32 b)
{
    uint32 result = (b & 1) ? a : 1;
    while ((b >>= 1) != 0)
    {
        a = modmul(a, a);
        if (b & 1)
        {
            result = modmul(result, a);
        }
    }
    return result;
}

#endif // H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

New modmul

// Division-free variant of fast_ntt::modmul(a,b), written as an MSVC style
// __asm block (x86, 32 bit registers).
//  - There is no C++ return statement: the last instruction leaves the value
//    in EAX.
//  - The locals _a and _b are assigned but never read; the asm block uses the
//    parameters a and b directly.
//  - The constants are 2863311530 = 0xAAAAAAAA and 3221225473 = 0xC0000001.
//    The modulus is hard-coded; the member p is not read here.
//  - Steps: MUL forms the 64 bit product (EDX:EAX, copied to ECX:EBX); a
//    quotient estimate is built from the high half with MUL 0xAAAAAAAA and
//    SHLD; estimate*0xC0000001 is subtracted from the product (SUB/SBB); the
//    carry and mask instructions that follow conditionally add or subtract
//    0xC0000001 once more.
//  - The registers EBX, ECX and EDX are modified.
// NOTE(review): this looks like a Barrett style reduction specialised for
// 0xC0000001; that the correction steps are sufficient for every a,b is not
// shown here - verify against the DIV based modmul before relying on it.
uint32 fast_ntt::modmul(uint32 a, uint32 b) { uint32 _a = a; uint32 _b = b; __asm { mov eax, a; mul b; mov ebx, eax; mov eax, 2863311530; mov ecx, edx; mul edx; shld edx, eax, 1; mov eax, 3221225473; mul edx; sub ebx, eax; mov eax, 3221225473; sbb ecx, edx; jc addback; neg ecx; and ecx, eax; sub ebx, ecx; sub ebx, eax; sbb edx, edx; and eax, edx; addback: add eax, ebx; }; } 

[EDIT] Spektre, based on your feedback I changed the modadd & modsub back to their original. I also realized I made some changes to the recursive NTT function I shouldn't have.

[EDIT2] Removed unneeded if statements and bitwise functions.

[EDIT3] Added new modmul inline assembly.

Sign up to request clarification or add additional context in comments.

32 Comments

Just a few things: I use a 32-bit compiler (not MS but Borland/Embarcadero), so operations like % are slower than a conditional subtraction. The code optimizations are also very different, so I see no speed gain from this on my setup (on x64 it may be different, but I use x32). uint64 is also a problem on the old compilers I use (their 64-bit arithmetic is very buggy). The only possible gain for me is in the +,- operations; I will test it more when I have the time, but I am afraid your speed gain is due to the different setup only...
As I suspected your modadd slows down mine NTT about 12%, modsub is the same (did not see it at first look due to your different code formating) so no this is not a good way for mine programing and HW environment.
Sorry to hear that. I'm curious though, did the inline assembly improve anything at all?
In BDS2006 asm is very tricky, because (I assume) it stores some state on the heap, which slows things down a bit, so C++ functions are often faster than the same code in asm (because of the asm directive). For more complex or slow operations asm is of course faster... (On MS and GCC compilers this does not happen at all, from what I have read and heard.)
I'm sorry for the confusion. I was referring to the modifications I made to the assembly in the modmul() function. I saw some major speedups on my machine, and I was wondering if you did as well?
|

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.