I am attempting to optimize a piece of C code which multiplies a series of pairs of unsigned shorts and adds up the results. I only care about the high 16 bits of the final sum, and I can guarantee that the sum of the products will fit in a 32-bit value. I initially coded this in plain C, then rewrote it using SSE2 intrinsics (slowest), and then rewrote it in SSE2 assembler (fastest so far). I am not an expert at x86 assembler and would appreciate any recommendations on how to speed this code up. Speed is the priority; this is a tight inner loop. It is OK to assume that input will be valid, and portability is not a major concern: this code will only be used on computers with Intel i5 or i7 processors. The three versions follow.
C
    register int i;
    uint16_t* iw_ptr = n->iw;
    uint16_t* nw_ptr = n->nw;
    register uint32_t accvalue = 0;

    if (id < 2 * I_CNT) {
        for (i = 0; i < I_CNT; i++) {
            accvalue += (uint32_t) i_ptr[i] * (uint32_t) iw_ptr[i];
        }
    }
    for (i = 0; i < id; i++) {
        accvalue += (uint32_t) n_ptr[i] * (uint32_t) nw_ptr[i];
    }
    value = (uint16_t) (accvalue >> 16);

C SSE2
    #define SSE2_I ((I_CNT+7)/8)
    #define SSE2_N ((N_CNT+7)/8)

    register int i;
    __m128i mm_sums;
    __m128i mm_arg1;
    __m128i mm_arg2;
    __m128i mm_accum = _mm_setzero_si128();
    __m128i mm_accum_mem[1];                  // scratch for the final horizontal add
    __m128i* mm_iptr  = (__m128i*) i_ptr;
    __m128i* mm_nptr  = (__m128i*) n_ptr;
    __m128i* mm_iwptr = (__m128i*) n->iw;
    __m128i* mm_nwptr = (__m128i*) n->nw;
    int id_8 = (id + 7) / 8;                  // Divide by 8, round up

    if (id < 2 * I_CNT) {
        for (i = 0; i < SSE2_I; i++) {
            mm_arg1 = _mm_loadu_si128(mm_iptr + i);
            mm_arg2 = _mm_loadu_si128(mm_iwptr + i);
            mm_sums = _mm_mulhi_epu16(mm_arg1, mm_arg2);
            mm_accum = _mm_adds_epu16(mm_accum, mm_sums);
        }
    }
    for (i = 0; i < id_8; i++) {
        mm_arg1 = _mm_loadu_si128(mm_nptr + i);
        mm_arg2 = _mm_loadu_si128(mm_nwptr + i);
        mm_sums = _mm_mulhi_epu16(mm_arg1, mm_arg2);
        mm_accum = _mm_adds_epu16(mm_accum, mm_sums);
    }
    _mm_storeu_si128(mm_accum_mem, mm_accum);
    value = 0;                                // sum the eight lanes
    for (i = 0; i < 8; i++) {
        value += *(((uint16_t*) mm_accum_mem) + i);
    }

SSE2 and x86 Assembler
    #define SSE2_I ((I_CNT+7)/8)
    #define SSE2_N ((N_CNT+7)/8)

    __m128i* mm_iptr  = (__m128i*) i_ptr;
    __m128i* mm_nptr  = (__m128i*) n_ptr;
    __m128i* mm_iwptr = (__m128i*) n->iw;
    __m128i* mm_nwptr = (__m128i*) n->nw;
    int id_8 = (id + 7) / 8;                 // Divide by 8, round up

    asm(
        "MOVL    %5, %%eax            \n\t"  // Current ID
        "MOVL    %6, %%ebx            \n\t"  // 2*I_CNT
        "CMP     %%ebx, %%eax         \n\t"  // If ID >= 2 * input count,
        "JGE     Nstart1              \n\t"  // skip the first step
        "MOVL    %1, %%eax            \n\t"
        "MOVL    %3, %%ebx            \n\t"
        "MOVDQA  (%%eax), %%xmm0      \n\t"
        "MOVDQA  (%%ebx), %%xmm1      \n\t"
        "PMULHUW %%xmm1, %%xmm0       \n\t"
    #if SSE2_I > 1
        "MOVDQA  0x10(%%eax), %%xmm1  \n\t"
        "MOVDQA  0x10(%%ebx), %%xmm2  \n\t"
        "PMULHUW %%xmm1, %%xmm2       \n\t"
        "PADDUSW %%xmm2, %%xmm0       \n\t"
    #endif
    #if SSE2_I > 2
        "MOVDQA  0x20(%%eax), %%xmm1  \n\t"
        "MOVDQA  0x20(%%ebx), %%xmm2  \n\t"
        "PMULHUW %%xmm1, %%xmm2       \n\t"
        "PADDUSW %%xmm2, %%xmm0       \n\t"
    #endif
    #if SSE2_I > 3
        "MOVDQA  0x30(%%eax), %%xmm1  \n\t"
        "MOVDQA  0x30(%%ebx), %%xmm2  \n\t"
        "PMULHUW %%xmm1, %%xmm2       \n\t"
        "PADDUSW %%xmm2, %%xmm0       \n\t"
    #endif
        "JMP     Nstart2              \n\t"

        "Nstart1:                     \n\t"  // This is our first multiplication
        "MOVL    %2, %%edi            \n\t"
        "MOVL    %4, %%esi            \n\t"
        "MOVDQA  (%%edi), %%xmm0      \n\t"
        "MOVDQA  (%%esi), %%xmm1      \n\t"
        "PMULHUW %%xmm1, %%xmm0       \n\t"
        "JMP     Nstart3              \n\t"

        "Nstart2:                     \n\t"  // This is not our first multiplication
        "MOVL    %2, %%edi            \n\t"
        "MOVL    %4, %%esi            \n\t"
        "MOVDQA  (%%edi), %%xmm1      \n\t"
        "MOVDQA  (%%esi), %%xmm2      \n\t"
        "PMULHUW %%xmm1, %%xmm2       \n\t"
        "PADDUSW %%xmm2, %%xmm0       \n\t"

        "Nstart3:                     \n\t"
        "MOVL    %7, %%ebx            \n\t"  // Current ID, divided by 8, rounded up:
                                             // the number of rounds we have to do
        "DEC     %%ebx                \n\t"  // If it is now 0 or -1,
        "JLE     Endloop              \n\t"  // we don't have to do any more rounds
        "MOVL    $0x10, %%eax         \n\t"  // The offset

        "Loop:                        \n\t"
        "MOVDQA  (%%edi, %%eax), %%xmm1 \n\t"
        "MOVDQA  (%%esi, %%eax), %%xmm2 \n\t"
        "PMULHUW %%xmm1, %%xmm2       \n\t"
        "PADDUSW %%xmm2, %%xmm0       \n\t"
        "ADDL    $0x10, %%eax         \n\t"
        "DEC     %%ebx                \n\t"  // The round we just did
        "JNE     Loop                 \n\t"  // If not zero, do it again

        "Endloop:                     \n\t"
        /**
         * PREPARE FOR THE ADDING OF THE WORDS
         * xmm0  3 2 1 0
         * xmm1  0 0 3 2      PSHUFD 0b00001110
         * xmm0  X X 1+3 0+2  PADDUSW
         * xmm1  X X X 1+3    PSHUFD 0b00000001
         * xmm0  X X X sum    PADDUSW
         * This, however, is still two words - we were shuffling doublewords.
         * But that's not all we can shuffle: shuffle the low words
         * (PSHUFLW 0b00000001), PADDUSW again, and extract the result
         * to a register as a 16-bit word.
         */
        "PSHUFD  $0x0E, %%xmm0, %%xmm1 \n\t"
        "PADDUSW %%xmm1, %%xmm0       \n\t"
        "PSHUFD  $0x01, %%xmm0, %%xmm1 \n\t"
        "PADDUSW %%xmm1, %%xmm0       \n\t"
        "PSHUFLW $0x01, %%xmm0, %%xmm1 \n\t"
        "PADDUSW %%xmm1, %%xmm0       \n\t"
        "PEXTRW  $0x01, %%xmm0, %%edx \n\t"  // Into edx to match the "=d" output
        : "=d" (value)                                                // Outputs
        : "g" (mm_iptr), "g" (mm_nptr), "g" (mm_iwptr), "g" (mm_nwptr),
          "g" (id), "g" (2*I_CNT), "g" (id_8)                         // Inputs
        : "%xmm0", "%xmm1", "%xmm2", "%eax", "%ebx", "%edi", "%esi"   // Clobbered
    );

I have tried to:
Use the prefetchnta instruction on a few of the pointers (such as %2 and %4); this only increased the execution time. One placement I tried appears in the sketch after the next item.
Use more of the XMM registers by grouping the MOVDQA/MOVDQA and PMULHUW/PADDUSW pairs across up to three iterations at a time. I expected this to help at least a little by batching the reads from memory and putting more instructions between a load and the first use of that data, but it provided no speedup. A sketch of what this looked like follows.
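Reconstructed from memory, the grouped loop body was roughly the following (a sketch, not the exact code: xmm3 and xmm4 are assumed free, id_8 is assumed even, and the 0x80 prefetch distance is just one of the values I tried):

    "Loop2:                              \n\t"
    "PREFETCHNTA 0x80(%%edi, %%eax)      \n\t"   // the prefetch attempt; made things slower
    "MOVDQA  (%%edi, %%eax), %%xmm1      \n\t"   // load two chunks of each array first...
    "MOVDQA  (%%esi, %%eax), %%xmm2      \n\t"
    "MOVDQA  0x10(%%edi, %%eax), %%xmm3  \n\t"
    "MOVDQA  0x10(%%esi, %%eax), %%xmm4  \n\t"
    "PMULHUW %%xmm1, %%xmm2              \n\t"   // ...then multiply and accumulate
    "PMULHUW %%xmm3, %%xmm4              \n\t"
    "PADDUSW %%xmm2, %%xmm0              \n\t"
    "PADDUSW %%xmm4, %%xmm0              \n\t"
    "ADDL    $0x20, %%eax                \n\t"   // advance past two 16-byte chunks
    "SUBL    $2, %%ebx                   \n\t"   // two rounds done
    "JG      Loop2                       \n\t"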
I would like to be able to skip the PMULHUW/PADDUSW sequence when either of the operands is all zeroes. However, MOVDQA does not set the zero flag, and I can't see any easy way to test an XMM register for zero. Is it possible to jump if an XMM register is all zeroes? The best I have come up with is sketched below.
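This is an untested sketch of the only test I can think of: compare the operand against zero with PCMPEQW, collect the comparison result into a general-purpose register with PMOVMSKB, and branch on that. It assumes xmm7 and ecx are free (ecx would have to be added to the clobber list), and the extra compare and branch per iteration may well cost more than the PMULHUW/PADDUSW it skips:

    "PXOR     %%xmm7, %%xmm7    \n\t"   // xmm7 = all zeroes
    "PCMPEQW  %%xmm1, %%xmm7    \n\t"   // each word of xmm7 = 0xFFFF where xmm1 word == 0
    "PMOVMSKB %%xmm7, %%ecx     \n\t"   // gather the top bit of all 16 bytes into ecx
    "CMPL     $0xFFFF, %%ecx    \n\t"   // all bits set => xmm1 was entirely zero
    "JE       SkipMulAdd        \n\t"   // skip the multiply-add for this chunk
    "PMULHUW  %%xmm1, %%xmm2    \n\t"
    "PADDUSW  %%xmm2, %%xmm0    \n\t"
    "SkipMulAdd:                \n\t"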
Did I screw up any of the assembly loop code?
Any other ways to get one of these running faster?