From 074741aee4e03b3508da28db566ddc02b8303b60 Mon Sep 17 00:00:00 2001 From: Quim Date: Wed, 22 Jan 2025 15:59:37 +0000 Subject: [PATCH 1/3] Add AVX2/AVX512 vectorization of the WFA extend --- .../WFA2-lib/wavefront/wavefront_extend.c | 4 - .../wavefront/wavefront_extend_kernels.c | 111 +++- .../wavefront/wavefront_extend_kernels_avx.c | 598 ++++++++++++++++-- .../wavefront/wavefront_extend_kernels_avx.h | 35 +- 4 files changed, 679 insertions(+), 69 deletions(-) diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend.c b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend.c index 1ba78866..77ede3f3 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend.c +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend.c @@ -54,11 +54,7 @@ void wavefront_extend_end2end_dispatcher_seq( wavefront_sequences_t* const seqs = &wf_aligner->sequences; // Check the sequence mode if (seqs->mode == wf_sequences_ascii) { -//#if __AVX2__ // TODO -// wavefront_extend_matches_packed_end2end_avx2(wf_aligner,mwavefront,lo,hi); -//#else wavefront_extend_matches_packed_end2end(wf_aligner,mwavefront,lo,hi); -//#endif } else { wf_offset_t dummy; wavefront_extend_matches_custom(wf_aligner,mwavefront,score,lo,hi,false,&dummy); diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels.c b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels.c index fe478693..04b98875 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels.c +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels.c @@ -29,13 +29,43 @@ * DESCRIPTION: WFA module for the "extension" of exact matches */ + +// Use cross-platform header +#include + #include "wavefront_extend_kernels.h" #include "wavefront_termination.h" +#include "wavefront_extend_kernels_avx.h" + +#if __BYTE_ORDER == __LITTLE_ENDIAN +#define wavefront_extend_matches_kernel wavefront_extend_matches_kernel_blockwise +#else +#define wavefront_extend_matches_kernel wavefront_extend_matches_kernel_charwise +#endif /* - * Inner-most extend kernel (blockwise comparisons) + * Inner-most extend kernel */ -FORCE_INLINE wf_offset_t wavefront_extend_matches_packed_kernel( +FORCE_INLINE wf_offset_t wavefront_extend_matches_kernel_charwise( + wavefront_aligner_t* const wf_aligner, + const int k, + wf_offset_t offset) { + // Fetch pattern/text + char* pattern_ptr = wf_aligner->sequences.pattern + WAVEFRONT_V(k,offset); + char* text_ptr = wf_aligner->sequences.text + WAVEFRONT_H(k,offset); + // Compare 64-bits blocks + while (*pattern_ptr == *text_ptr) { + // Increment offset + offset++; + // Next chars + ++pattern_ptr; + ++text_ptr; + } + // Return extended offset + return offset; +} + +FORCE_INLINE wf_offset_t wavefront_extend_matches_kernel_blockwise( wavefront_aligner_t* const wf_aligner, const int k, wf_offset_t offset) { @@ -60,6 +90,7 @@ FORCE_INLINE wf_offset_t wavefront_extend_matches_packed_kernel( // Return extended offset return offset; } + /* * Wavefront-Extend Inner Kernels * Wavefront offset extension comparing characters @@ -72,42 +103,69 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end( wavefront_t* const mwavefront, const int lo, const int hi) { - wf_offset_t* const offsets = mwavefront->offsets; - int k; - for (k=lo;k<=hi;++k) { - // Fetch offset - const wf_offset_t offset = offsets[k]; - if (offset == WAVEFRONT_OFFSET_NULL) continue; - // Extend offset - offsets[k] = wavefront_extend_matches_packed_kernel(wf_aligner,k,offset); - } + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + wavefront_extend_matches_packed_end2end_avx512(wf_aligner, mwavefront, lo, hi); + #else + wavefront_extend_matches_packed_end2end_avx2(wf_aligner, mwavefront, lo, hi); + #endif + #else + wf_offset_t* const offsets = mwavefront->offsets; + int k; + for (k=lo;k<=hi;++k) { + // Fetch offset + const wf_offset_t offset = offsets[k]; + if (offset == WAVEFRONT_OFFSET_NULL) continue; + // Extend offset + offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); + } + #endif } + FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max( wavefront_aligner_t* const wf_aligner, wavefront_t* const mwavefront, const int lo, const int hi) { - wf_offset_t* const offsets = mwavefront->offsets; - wf_offset_t max_antidiag = 0; - int k; - for (k=lo;k<=hi;++k) { - // Fetch offset - const wf_offset_t offset = offsets[k]; - if (offset == WAVEFRONT_OFFSET_NULL) continue; - // Extend offset - offsets[k] = wavefront_extend_matches_packed_kernel(wf_aligner,k,offset); - // Compute max - const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(k,offsets[k]); - if (max_antidiag < antidiag) max_antidiag = antidiag; - } - return max_antidiag; + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + //printf("AVX512\n"); + return wavefront_extend_matches_packed_end2end_max_avx512(wf_aligner, mwavefront, lo, hi); + #else + //printf("AVX2\n"); + return wavefront_extend_matches_packed_end2end_max_avx2(wf_aligner, mwavefront, lo, hi); + #endif + #else + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + int k; + for (k=lo;k<=hi;++k) { + // Fetch offset + const wf_offset_t offset = offsets[k]; + if (offset == WAVEFRONT_OFFSET_NULL) continue; + // Extend offset + offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); + // Compute max + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(k,offsets[k]); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + return max_antidiag; + #endif } + FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree( wavefront_aligner_t* const wf_aligner, wavefront_t* const mwavefront, const int score, const int lo, const int hi) { + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + return wavefront_extend_matches_packed_endsfree_avx512(wf_aligner, mwavefront, score, lo, hi); + #else + return wavefront_extend_matches_packed_endsfree_avx2(wf_aligner, mwavefront, score, lo, hi); + #endif + #else // Parameters wf_offset_t* const offsets = mwavefront->offsets; int k; @@ -116,7 +174,7 @@ FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree( wf_offset_t offset = offsets[k]; if (offset == WAVEFRONT_OFFSET_NULL) continue; // Extend offset - offset = wavefront_extend_matches_packed_kernel(wf_aligner,k,offset); + offset = wavefront_extend_matches_kernel(wf_aligner,k,offset); offsets[k] = offset; // Check ends-free reaching boundaries if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,k,offset)) { @@ -134,6 +192,7 @@ FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree( } // Alignment not finished return false; + #endif } /* * Wavefront-Extend Inner Kernel (Custom match function) diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.c b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.c index b31d8d55..6079d9e6 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.c +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.c @@ -35,6 +35,7 @@ #include "wavefront_heuristic.h" #include "wavefront_extend_kernels.h" #include "wavefront_extend_kernels_avx.h" +#include "wavefront_termination.h" #if __AVX2__ #include @@ -66,25 +67,29 @@ FORCE_INLINE wf_offset_t wavefront_extend_matches_packed_kernel( // Return extended offset return offset; } + /* * SIMD clz, use a native instruction when available (AVX512 CD or VL * extensions), or emulate the clz behavior. */ FORCE_INLINE __m256i avx2_lzcnt_epi32(__m256i v) { -#if __AVX512CD__ && __AVX512VL__ - return _mm256_lzcnt_epi32(v); -#else - // Emulate clz for AVX2: https://stackoverflow.com/a/58827596 - v = _mm256_andnot_si256(_mm256_srli_epi32(v,8),v); // keep 8 MSB - v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float - v = _mm256_srli_epi32(v,23); // shift down the exponent - v = _mm256_subs_epu16(_mm256_set1_epi32(158),v); // undo bias - v = _mm256_min_epi16(v,_mm256_set1_epi32(32)); // clamp at 32 - return v; -#endif + #if __AVX512CD__ && __AVX512VL__ + return _mm256_lzcnt_epi32(v); + #else + // Emulate clz for AVX2: https://stackoverflow.com/a/58827596 + v = _mm256_andnot_si256(_mm256_srli_epi32(v,8),v); // keep 8 MSB + v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float + v = _mm256_srli_epi32(v,23); // shift down the exponent + v = _mm256_subs_epu16(_mm256_set1_epi32(158),v); // undo bias + v = _mm256_min_epi16(v,_mm256_set1_epi32(32)); // clamp at 32 + return v; + #endif } + + + /* - * Wavefront-Extend Inner Kernel (SIMD AVX2/AVX512) + * Wavefront-Extend Inner Kernel (SIMD AVX2) */ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( wavefront_aligner_t* const wf_aligner, @@ -98,7 +103,6 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( const char* pattern = wf_aligner->sequences.pattern; const char* text = wf_aligner->sequences.text; const __m256i vector_null = _mm256_set1_epi32(-1); - const __m256i fours = _mm256_set1_epi32(4); const __m256i eights = _mm256_set1_epi32(8); const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, 20,21,22,23,16,17,18,19, @@ -115,40 +119,41 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( offsets[k] = wavefront_extend_matches_packed_kernel(wf_aligner,k,offset); } if (num_of_diagonals < elems_per_register) return; + k_min += loop_peeling_iters; __m256i ks = _mm256_set_epi32 ( k_min+7,k_min+6,k_min+5,k_min+4, k_min+3,k_min+2,k_min+1,k_min); - // Main SIMD extension loop - for (k=k_min;k<=k_max;k+=elems_per_register) { + + + for (k=k_min; k<=k_max; k+=elems_per_register) { __m256i offsets_vector = _mm256_lddqu_si256 ((__m256i*)&offsets[k]); - __m256i h_vector = offsets_vector; - __m256i v_vector = _mm256_sub_epi32(offsets_vector,ks); - ks =_mm256_add_epi32 (ks, eights); + __m256i h_vector = offsets_vector; + __m256i v_vector = _mm256_sub_epi32(offsets_vector,ks); + // NULL offsets will read at index 0 (avoid segfaults) - __m256i null_mask = _mm256_cmpgt_epi32(offsets_vector,vector_null); - v_vector = _mm256_and_si256(null_mask,v_vector); - h_vector = _mm256_and_si256(null_mask,h_vector); + __m256i null_mask = _mm256_cmpgt_epi32(offsets_vector, vector_null); + v_vector = _mm256_and_si256(null_mask, v_vector); + h_vector = _mm256_and_si256(null_mask, h_vector); + __m256i pattern_vector = _mm256_i32gather_epi32((int const*)&pattern[0],v_vector,1); - __m256i text_vector = _mm256_i32gather_epi32((int const*)&text[0],h_vector,1); - // Change endianess to make the xor + clz character comparison - pattern_vector = _mm256_shuffle_epi8(pattern_vector,vecShuffle); - text_vector = _mm256_shuffle_epi8(text_vector,vecShuffle); + __m256i text_vector = _mm256_i32gather_epi32((int const*)&text[0],h_vector,1); + __m256i vector_mask = _mm256_cmpeq_epi32(text_vector, pattern_vector); + int mask = _mm256_movemask_epi8(vector_mask); + __m256i xor_result_vector = _mm256_xor_si256(pattern_vector,text_vector); - __m256i clz_vector = avx2_lzcnt_epi32(xor_result_vector); - // Divide clz by 8 to get the number of equal characters - // Assume there are sentinels on sequences so we won't count characters - // outside the sequences - __m256i equal_chars = _mm256_srli_epi32(clz_vector,3); - offsets_vector = _mm256_add_epi32 (offsets_vector,equal_chars); - v_vector = _mm256_add_epi32 (v_vector,fours); - h_vector = _mm256_add_epi32 (h_vector,fours); - // Lanes to continue == 0xffffffff, other lanes = 0 - __m256i vector_mask = _mm256_cmpeq_epi32(equal_chars,fours); + xor_result_vector = _mm256_shuffle_epi8(xor_result_vector, vecShuffle); + __m256i clz_vector = avx2_lzcnt_epi32(xor_result_vector); + + __m256i equal_chars = _mm256_srli_epi32(clz_vector,3); + //equal_chars = _mm256_and_si256(null_mask, equal_chars); + offsets_vector = _mm256_add_epi32 (offsets_vector,equal_chars); + ks = _mm256_add_epi32 (ks, eights); + _mm256_storeu_si256((__m256i*)&offsets[k],offsets_vector); - int mask = _mm256_movemask_epi8(vector_mask); + if(mask == 0) continue; - // ctz(0) is undefined + while (mask != 0) { int tz = __builtin_ctz(mask); int curr_k = k + (tz/4); @@ -164,4 +169,521 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( } } -#endif // AVX2 + +FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 8; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m256i vector_null = _mm256_set1_epi32(-1); + const __m256i eights = _mm256_set1_epi32(8); + const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15, 8, 9,10,11, + 4 , 5, 6, 7, 0, 1, 2 ,3); + + + for (k=k_min;k= 0) { + offset = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + offsets[curr_k] = offset; + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(curr_k, offset); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + mask &= (0xfffffff0 << tz); + } + max_antidiag_v = _mm256_set1_epi32(max_antidiag); + } + } + + const wf_offset_t max_antidiagonal_buffer[8]; + _mm256_storeu_si256((__m256i*)&max_antidiagonal_buffer[0], max_antidiag_v); + for (int i = 0; i < 8; i++) + { + const wf_offset_t antidiag = max_antidiagonal_buffer[i]; + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + return max_antidiag; +} + + +FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 8; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m256i vector_null = _mm256_set1_epi32(-1); + const __m256i eights = _mm256_set1_epi32(8); + const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15, 8, 9,10,11, + 4 , 5, 6, 7, 0, 1, 2 ,3); + + for (k=k_min;k= 0) { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,curr_k,offsets[curr_k])) { + return true; // Quit (we are done) + } + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + mask &= (0xfffffff0 << tz); + } + } + for (k=k_min; k <= k_max; k++) { + const wf_offset_t offset = offsets[k]; + if (offset < 0) continue; + // Check ends-free reaching boundaries + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,k,offset)) { + return true; // Quit (we are done) + } + } + return false; +} + + +#if __AVX512CD__ && __AVX512VL__ +/* + * Wavefront-Extend Inner Kernel (SIMD AVX512) + */ +FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + wf_offset_t* const offsets = mwavefront->offsets; + int k_min = lo; + int k_max = hi; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + const int elems_per_register = 16; + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + const wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + } +} + + +FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 16; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) + { + offset = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + offsets[curr_k] = offset; + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(curr_k, offset); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + else + { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + max_antidiag_v = _mm512_set1_epi32(max_antidiag); + } + + return _mm512_reduce_max_epi32(max_antidiag_v); +} + + +FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi) { + + // Parameters + wf_offset_t* const offsets = mwavefront->offsets; + int k_min = lo; + int k_max = hi; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + const int elems_per_register = 16; + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + const wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) + { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,curr_k,offsets[curr_k])) { + return true; // Quit (we are done) + } + } + else + { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + } + + for (k=k_min; k <= k_max; k++) { + const wf_offset_t offset = offsets[k]; + if (offset < 0) continue; + // Check ends-free reaching boundaries + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,k,offset)) { + return true; // Quit (we are done) + } + } + return false; +} +#endif + +#endif // AVX2 \ No newline at end of file diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.h b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.h index 0e932d4e..dd936866 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.h +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_extend_kernels_avx.h @@ -42,6 +42,39 @@ void wavefront_extend_matches_packed_end2end_avx2( const int lo, const int hi); +wf_offset_t wavefront_extend_matches_packed_end2end_max_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +bool wavefront_extend_matches_packed_endsfree_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi); + +#if __AVX512CD__ && __AVX512VL__ +void wavefront_extend_matches_packed_end2end_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +wf_offset_t wavefront_extend_matches_packed_end2end_max_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +bool wavefront_extend_matches_packed_endsfree_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi); +#endif #endif // AVX2 -#endif /* WAVEFRONT_EXTEND_AVX_H_ */ +#endif /* WAVEFRONT_EXTEND_AVX_H_ */ \ No newline at end of file From 2752bd46b8a51c160665abad43080e9ce483bb8e Mon Sep 17 00:00:00 2001 From: Quim Date: Wed, 22 Jan 2025 16:00:34 +0000 Subject: [PATCH 2/3] Add AVX512 vectorization of breakpoint_i2i and m2m --- .../WFA2-lib/wavefront/wavefront_bialign.c | 267 +++++++++++++++++- 1 file changed, 263 insertions(+), 4 deletions(-) diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c index 7faca6b0..986bce95 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c @@ -42,6 +42,10 @@ #include "wavefront_plot.h" #include "wavefront_debug.h" +#if __AVX2__ +#include +#endif + /* * Config */ @@ -195,6 +199,8 @@ void wavefront_bialign_breakpoint_indel2indel( wavefront_t* const dwf_1, const affine2p_matrix_type component, wf_bialign_breakpoint_t* const breakpoint) { +#if __AVX2__ && __AVX512CD__ && __AVX512VL__ + // AVX512 implementation of the bialign_breakpoint_indel2indel // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; @@ -211,6 +217,131 @@ void wavefront_bialign_breakpoint_indel2indel( // Compute overlapping interval const int min_hi = MIN(hi_0,hi_1); const int max_lo = MAX(lo_0,lo_1); + + if (score_0 + score_1 - gap_open >= breakpoint->score) return; + + const int elems_per_register = 16; + int k_0 = max_lo; + int num_diagonals = min_hi - max_lo + 1; + int loop_peeling_iters = num_diagonals % elems_per_register; + + // Scalar pass to peel off the first few iterations, and make the main loop a + // multiple of the register size + for (;k_0offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length && dh_0 <= text_length && dh_1 <= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + } + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + // wavefront_bialign_debug(breakpoint,-1); // DEBUG + // No need to keep searching + return; + } + } + // Finish the remaining iterations in a vectorized manner + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i tlens = _mm512_set1_epi32(text_length); + const __m512i rev = _mm512_setr_epi32(15,14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + const __m512i min_hi_vector = _mm512_set1_epi32(min_hi); + __m512i ks = _mm512_set_epi32 ( + k_0+15,k_0+14,k_0+13,k_0+12,k_0+11,k_0+10,k_0+9,k_0+8, + k_0+7,k_0+6,k_0+5,k_0+4,k_0+3,k_0+2,k_0+1,k_0); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m512i doffsets_0 = _mm512_loadu_si512((__m512i*)&dwf_0->offsets[k_0]); + __m512i doffsets_1 = _mm512_loadu_si512((__m512i*)&dwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + doffsets_1 = _mm512_permutexvar_epi32(rev, doffsets_1); + __m512i dh_0s = doffsets_0; + __m512i dh_1s = doffsets_1; + __mmask16 bp_found_mask =_mm512_cmpge_epi32_mask(_mm512_add_epi32(dh_0s, dh_1s), tlens); + bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, ks, min_hi_vector); + bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, dh_0s, tlens); + bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, dh_1s, tlens); + + if (bp_found_mask) { + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length && dh_0 <= text_length && dh_1 <= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + } + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + // wavefront_bialign_debug(breakpoint,-1); // DEBUG + // No need to keep searching + return; + } + } + } + // Update ks for the next iteration + ks = _mm512_add_epi32(ks, sixteens); + } +#else // Scalar implementation of the bialign_breakpoint_indel2indel + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + const int gap_open = + (component==affine2p_matrix_I1 || component==affine2p_matrix_D1) ? + wf_aligner->penalties.gap_opening1 : wf_aligner->penalties.gap_opening2; + + + // Check wavefronts overlapping + const int lo_0 = dwf_0->lo; + const int hi_0 = dwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(dwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(dwf_1->lo,pattern_length,text_length); + if (hi_1 < lo_0 || hi_0 < lo_1) return; + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + + if (score_0 + score_1 - gap_open >= breakpoint->score) return; + int k_0; for (k_0=max_lo;k_0<=min_hi;k_0++) { const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); @@ -220,9 +351,13 @@ void wavefront_bialign_breakpoint_indel2indel( const int dh_0 = WAVEFRONT_H(k_0,doffset_0); const int dh_1 = WAVEFRONT_H(k_1,doffset_1); // Check breakpoint d2d - if (dh_0 + dh_1 >= text_length && score_0 + score_1 - gap_open < breakpoint->score && - dh_0 <= text_length && dh_1 <= text_length) { + if (dh_0 + dh_1 >= text_length) { if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint breakpoint->score_forward = score_0; breakpoint->score_reverse = score_1; breakpoint->k_forward = k_0; @@ -230,6 +365,11 @@ void wavefront_bialign_breakpoint_indel2indel( breakpoint->offset_forward = dh_0; breakpoint->offset_reverse = dh_1; } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint breakpoint->score_forward = score_1; breakpoint->score_reverse = score_0; breakpoint->k_forward = k_1; @@ -243,7 +383,9 @@ void wavefront_bialign_breakpoint_indel2indel( // No need to keep searching return; } + } +#endif // AVX512 } void wavefront_bialign_breakpoint_m2m( wavefront_aligner_t* const wf_aligner, @@ -254,6 +396,122 @@ void wavefront_bialign_breakpoint_m2m( wavefront_t* const mwf_1, wf_bialign_breakpoint_t* const breakpoint) { // Parameters + if (score_0 + score_1 >= breakpoint->score) return; +#if __AVX2__ && __AVX512CD__ && __AVX512VL__ + // AVX512 implementation of the bialign_breakpoint_indel2indel + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + // Check wavefronts overlapping + const int lo_0 = mwf_0->lo; + const int hi_0 = mwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(mwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(mwf_1->lo,pattern_length,text_length); + if (hi_1 < lo_0 || hi_0 < lo_1) return; + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + + const int elems_per_register = 16; + int k_0 = max_lo; + int num_diagonals = min_hi - max_lo + 1; + int loop_peeling_iters = num_diagonals % elems_per_register; + + // Scalar pass to peel off the first few iterations, and make the main loop a + // multiple of the register size + for (;k_0offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + // wavefront_bialign_debug(breakpoint,-1); // DEBUG + // No need to keep searching + return; + } + } + // Finish the remaining iterations in a vectorized manner + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i tlens = _mm512_set1_epi32(text_length); + const __m512i rev = _mm512_setr_epi32(15,14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + const __m512i min_hi_vector = _mm512_set1_epi32(min_hi); + __m512i ks = _mm512_set_epi32 ( + k_0+15,k_0+14,k_0+13,k_0+12,k_0+11,k_0+10,k_0+9,k_0+8, + k_0+7,k_0+6,k_0+5,k_0+4,k_0+3,k_0+2,k_0+1,k_0); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m512i doffsets_0 = _mm512_loadu_si512((__m512i*)&mwf_0->offsets[k_0]); + __m512i doffsets_1 = _mm512_loadu_si512((__m512i*)&mwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + doffsets_1 = _mm512_permutexvar_epi32(rev, doffsets_1); + __m512i dh_0s = doffsets_0; + __m512i dh_1s = doffsets_1; + __mmask16 bp_found_mask =_mm512_cmpge_epi32_mask(_mm512_add_epi32(dh_0s, dh_1s), tlens); + + if (bp_found_mask) { + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + // wavefront_bialign_debug(breakpoint,-1); // DEBUG + // No need to keep searching + return; + } + } + } + // Update ks for the next iteration + ks = _mm512_add_epi32(ks, sixteens); + } +#else // Scalar implementation of the bialign_breakpoint_indel2indel wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; const int pattern_length = sequences->pattern_length; @@ -275,7 +533,7 @@ void wavefront_bialign_breakpoint_m2m( const int mh_0 = WAVEFRONT_H(k_0,moffset_0); const int mh_1 = WAVEFRONT_H(k_1,moffset_1); // Check breakpoint m2m - if (mh_0 + mh_1 >= text_length && score_0 + score_1 < breakpoint->score) { + if (mh_0 + mh_1 >= text_length) { if (breakpoint_forward) { breakpoint->score_forward = score_0; breakpoint->score_reverse = score_1; @@ -298,6 +556,7 @@ void wavefront_bialign_breakpoint_m2m( return; } } +#endif // AVX512 } /* * Bidirectional find overlaps @@ -718,4 +977,4 @@ void wavefront_bialign( } else { // Other cases wf_aligner->align_status.status = WF_STATUS_UNATTAINABLE; } -} +} \ No newline at end of file From 180dc1409beef182393f6051f4fc7b6e6af25d1b Mon Sep 17 00:00:00 2001 From: Quim Date: Thu, 23 Jan 2025 13:26:41 +0000 Subject: [PATCH 3/3] Improved breakpoint i2i and m2m vectorization --- .../WFA2-lib/wavefront/wavefront_bialign.c | 456 ++++++++++++++---- 1 file changed, 375 insertions(+), 81 deletions(-) diff --git a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c index 986bce95..a1205fd8 100644 --- a/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c +++ b/src/common/wflign/deps/WFA2-lib/wavefront/wavefront_bialign.c @@ -27,12 +27,10 @@ * PROJECT: Wavefront Alignment Algorithms * AUTHOR(S): Santiago Marco-Sola */ - #include "utils/commons.h" #include "wavefront_bialign.h" #include "wavefront_unialign.h" #include "wavefront_bialigner.h" - #include "wavefront_compute.h" #include "wavefront_compute_affine.h" #include "wavefront_compute_affine2p.h" @@ -41,18 +39,15 @@ #include "wavefront_extend.h" #include "wavefront_plot.h" #include "wavefront_debug.h" - #if __AVX2__ #include #endif - /* * Config */ #define WF_BIALIGN_FALLBACK_MIN_SCORE 250 #define WF_BIALIGN_FALLBACK_MIN_LENGTH 100 #define WF_BIALIGN_RECOVERY_MIN_SCORE 500 - /* * Debug */ @@ -190,7 +185,9 @@ int wavefront_bialign_base( /* * Bidirectional check breakpoints */ -void wavefront_bialign_breakpoint_indel2indel( +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ +void wavefront_bialign_breakpoint_indel2indel_avx512( wavefront_aligner_t* const wf_aligner, const bool breakpoint_forward, const int score_0, @@ -199,7 +196,6 @@ void wavefront_bialign_breakpoint_indel2indel( wavefront_t* const dwf_1, const affine2p_matrix_type component, wf_bialign_breakpoint_t* const breakpoint) { -#if __AVX2__ && __AVX512CD__ && __AVX512VL__ // AVX512 implementation of the bialign_breakpoint_indel2indel // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; @@ -217,14 +213,11 @@ void wavefront_bialign_breakpoint_indel2indel( // Compute overlapping interval const int min_hi = MIN(hi_0,hi_1); const int max_lo = MAX(lo_0,lo_1); - if (score_0 + score_1 - gap_open >= breakpoint->score) return; - const int elems_per_register = 16; int k_0 = max_lo; int num_diagonals = min_hi - max_lo + 1; int loop_peeling_iters = num_diagonals % elems_per_register; - // Scalar pass to peel off the first few iterations, and make the main loop a // multiple of the register size for (;k_0= text_length && dh_0 <= text_length && dh_1 <= text_length) { + if (dh_0 + dh_1 >= text_length) { if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint breakpoint->score_forward = score_0; breakpoint->score_reverse = score_1; breakpoint->k_forward = k_0; @@ -244,6 +242,11 @@ void wavefront_bialign_breakpoint_indel2indel( breakpoint->offset_forward = dh_0; breakpoint->offset_reverse = dh_1; } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint breakpoint->score_forward = score_1; breakpoint->score_reverse = score_0; breakpoint->k_forward = k_1; @@ -259,13 +262,8 @@ void wavefront_bialign_breakpoint_indel2indel( } } // Finish the remaining iterations in a vectorized manner - const __m512i sixteens = _mm512_set1_epi32(16); const __m512i tlens = _mm512_set1_epi32(text_length); - const __m512i rev = _mm512_setr_epi32(15,14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); - const __m512i min_hi_vector = _mm512_set1_epi32(min_hi); - __m512i ks = _mm512_set_epi32 ( - k_0+15,k_0+14,k_0+13,k_0+12,k_0+11,k_0+10,k_0+9,k_0+8, - k_0+7,k_0+6,k_0+5,k_0+4,k_0+3,k_0+2,k_0+1,k_0); + const __m512i rev = _mm512_setr_epi32(15,14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); for (;k_0<=min_hi;k_0+=elems_per_register) { const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); // Fetch offsets @@ -273,54 +271,228 @@ void wavefront_bialign_breakpoint_indel2indel( __m512i doffsets_1 = _mm512_loadu_si512((__m512i*)&dwf_1->offsets[k_1-elems_per_register+1]); // doffsets_1 are in reverse order, so we need to reverse them doffsets_1 = _mm512_permutexvar_epi32(rev, doffsets_1); - __m512i dh_0s = doffsets_0; - __m512i dh_1s = doffsets_1; - __mmask16 bp_found_mask =_mm512_cmpge_epi32_mask(_mm512_add_epi32(dh_0s, dh_1s), tlens); - bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, ks, min_hi_vector); - bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, dh_0s, tlens); - bp_found_mask = _mm512_mask_cmple_epi32_mask(bp_found_mask, dh_1s, tlens); - - if (bp_found_mask) { - // A breakpoint has been found! Check in which exact diagonal it is - // This can be done directly from the mask and vector registers, for now, - // it is implemented like the scalar implementation. This only happens - // when a BP is found, so it should not be a bottleneck. - int initial_k0 = k_0; - for (;k_0offsets[k_0]; - const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; - const int dh_0 = WAVEFRONT_H(k_0,doffset_0); - const int dh_1 = WAVEFRONT_H(k_1,doffset_1); - // Check breakpoint d2d - if (dh_0 + dh_1 >= text_length && dh_0 <= text_length && dh_1 <= text_length) { - if (breakpoint_forward) { - breakpoint->score_forward = score_0; - breakpoint->score_reverse = score_1; - breakpoint->k_forward = k_0; - breakpoint->k_reverse = k_1; - breakpoint->offset_forward = dh_0; - breakpoint->offset_reverse = dh_1; - } else { - breakpoint->score_forward = score_1; - breakpoint->score_reverse = score_0; - breakpoint->k_forward = k_1; - breakpoint->k_reverse = k_0; - breakpoint->offset_forward = dh_1; - breakpoint->offset_reverse = dh_0; - } + __mmask16 bp_found_mask =_mm512_cmpge_epi32_mask(_mm512_add_epi32(doffsets_0, doffsets_1), tlens); + if (__builtin_expect(bp_found_mask == 0, 1)) continue; + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length) { + if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } + } + } + k_0 = initial_k0; + } +} +#else // No AVX512, but AVX2 available +void wavefront_bialign_breakpoint_indel2indel_avx2( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const dwf_0, + wavefront_t* const dwf_1, + const affine2p_matrix_type component, + wf_bialign_breakpoint_t* const breakpoint) { + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + const int gap_open = + (component==affine2p_matrix_I1 || component==affine2p_matrix_D1) ? + wf_aligner->penalties.gap_opening1 : wf_aligner->penalties.gap_opening2; + if (score_0 + score_1 - gap_open >= breakpoint->score) return; + // Check wavefronts overlapping + const int lo_0 = dwf_0->lo; + const int hi_0 = dwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(dwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(dwf_1->lo,pattern_length,text_length); + if (hi_1 < lo_0 || hi_0 < lo_1) return; + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + const int elems_per_register = 8; + const int num_diagonals = min_hi - max_lo + 1; + const int loop_peeling_iters = num_diagonals % elems_per_register; + int k_0; + for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + const wf_offset_t doffset_0 = dwf_0->offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length) { + if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } + } + } + // Finish the remaining iterations in a vectorized manner + const __m256i tlens = _mm256_set1_epi32(text_length-1);//enable change >= to > + const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m256i doffsets_0 = _mm256_lddqu_si256((__m256i*)&dwf_0->offsets[k_0]); + __m256i doffsets_1 = _mm256_lddqu_si256((__m256i*)&dwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + doffsets_1 = _mm256_permutevar8x32_epi32(doffsets_1, rev); + __m256i dh_0_1 = _mm256_add_epi32(doffsets_0, doffsets_1); + __m256i mask = _mm256_cmpgt_epi32(dh_0_1, tlens); + int bp_found_mask = _mm256_movemask_epi8(mask); + if (__builtin_expect(bp_found_mask == 0, 1)) continue; + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length) { + if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; breakpoint->score = score_0 + score_1 - gap_open; breakpoint->component = component; - // wavefront_bialign_debug(breakpoint,-1); // DEBUG - // No need to keep searching return; } } } - // Update ks for the next iteration - ks = _mm512_add_epi32(ks, sixteens); + k_0 = initial_k0; } -#else // Scalar implementation of the bialign_breakpoint_indel2indel +} +#endif // AVX512 +#endif // AVX2 +void wavefront_bialign_breakpoint_indel2indel( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const dwf_0, + wavefront_t* const dwf_1, + const affine2p_matrix_type component, + wf_bialign_breakpoint_t* const breakpoint) { +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ + wavefront_bialign_breakpoint_indel2indel_avx512( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + dwf_0, + dwf_1, + component, + breakpoint); +#else + wavefront_bialign_breakpoint_indel2indel_avx2( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + dwf_0, + dwf_1, + component, + breakpoint); +#endif // AVX512 +#else // Scalar implementation // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; @@ -328,8 +500,6 @@ void wavefront_bialign_breakpoint_indel2indel( const int gap_open = (component==affine2p_matrix_I1 || component==affine2p_matrix_D1) ? wf_aligner->penalties.gap_opening1 : wf_aligner->penalties.gap_opening2; - - // Check wavefronts overlapping const int lo_0 = dwf_0->lo; const int hi_0 = dwf_0->hi; @@ -339,9 +509,7 @@ void wavefront_bialign_breakpoint_indel2indel( // Compute overlapping interval const int min_hi = MIN(hi_0,hi_1); const int max_lo = MAX(lo_0,lo_1); - if (score_0 + score_1 - gap_open >= breakpoint->score) return; - int k_0; for (k_0=max_lo;k_0<=min_hi;k_0++) { const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); @@ -383,11 +551,12 @@ void wavefront_bialign_breakpoint_indel2indel( // No need to keep searching return; } - } -#endif // AVX512 +#endif // AVX2 } -void wavefront_bialign_breakpoint_m2m( +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ +void wavefront_bialign_breakpoint_m2m_avx512( wavefront_aligner_t* const wf_aligner, const bool breakpoint_forward, const int score_0, @@ -395,9 +564,6 @@ void wavefront_bialign_breakpoint_m2m( wavefront_t* const mwf_0, wavefront_t* const mwf_1, wf_bialign_breakpoint_t* const breakpoint) { - // Parameters - if (score_0 + score_1 >= breakpoint->score) return; -#if __AVX2__ && __AVX512CD__ && __AVX512VL__ // AVX512 implementation of the bialign_breakpoint_indel2indel // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; @@ -412,12 +578,10 @@ void wavefront_bialign_breakpoint_m2m( // Compute overlapping interval const int min_hi = MIN(hi_0,hi_1); const int max_lo = MAX(lo_0,lo_1); - const int elems_per_register = 16; int k_0 = max_lo; int num_diagonals = min_hi - max_lo + 1; int loop_peeling_iters = num_diagonals % elems_per_register; - // Scalar pass to peel off the first few iterations, and make the main loop a // multiple of the register size for (;k_0= breakpoint->score) return; + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + // Check wavefronts overlapping + const int lo_0 = mwf_0->lo; + const int hi_0 = mwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(mwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(mwf_1->lo,pattern_length,text_length); + if (hi_1 < lo_0 || hi_0 < lo_1) return; + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + const int elems_per_register = 8; + const int num_diagonals = min_hi - max_lo + 1; + const int loop_peeling_iters = num_diagonals % elems_per_register; + int k_0; + for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + const wf_offset_t moffset_0 = mwf_0->offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + return; + } + } + const __m256i tlens = _mm256_set1_epi32(text_length-1); //enable change >= to > + const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m256i moffsets_0 = _mm256_lddqu_si256((__m256i*)&mwf_0->offsets[k_0]); + __m256i moffsets_1 = _mm256_lddqu_si256((__m256i*)&mwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + moffsets_1 = _mm256_permutevar8x32_epi32(moffsets_1, rev); + __m256i mh_0_1 = _mm256_add_epi32(moffsets_0, moffsets_1); + __m256i mask = _mm256_cmpgt_epi32(mh_0_1, tlens); + int bp_found_mask = _mm256_movemask_epi8(mask); + if (__builtin_expect(bp_found_mask == 0, 1)) continue; + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + return; + } + } + } +} +#endif // AVX2 +#endif // AVX512 +void wavefront_bialign_breakpoint_m2m( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const mwf_0, + wavefront_t* const mwf_1, + wf_bialign_breakpoint_t* const breakpoint) { + // Parameters + if (score_0 + score_1 >= breakpoint->score) return; +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ + wavefront_bialign_breakpoint_m2m_avx512( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + mwf_0, + mwf_1, + breakpoint); +#else + wavefront_bialign_breakpoint_m2m_avx2( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + mwf_0, + mwf_1, + breakpoint); +#endif // AVX2 +#else // Scalar implementation wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; const int pattern_length = sequences->pattern_length; @@ -556,7 +850,7 @@ void wavefront_bialign_breakpoint_m2m( return; } } -#endif // AVX512 +#endif // AVX2 } /* * Bidirectional find overlaps