@@ -145,6 +145,28 @@ void reduce_head_half(
 }
 #endif
 
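+/*
+ * Grouped variant of reduce_head for MQA/GQA: kv_head_group_size query heads
+ * share one key head, so the query/key dot product is reduced once per query
+ * head in the group, writing each result attn_w_stride apart.
+ */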
+template <typename T>
+void reduce_head(
+    const T* q_ptr_start,
+    int64_t kv_head_group_size,
+    const T* k_ptr_start,
+    float* attn_w_pos,
+    int attn_w_stride,
+    int64_t head_size,
+    bool store_key,
+    T* k_cache_start) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    attn_w_pos[i * attn_w_stride] = 0;
+    reduce_head<T>(
+        q_ptr_start + i * head_size,
+        k_ptr_start,
+        attn_w_pos + i * attn_w_stride,
+        head_size,
+        store_key,
+        k_cache_start);
+  }
+}
+
 /*
  *reduce the attention_weights with the value embedding by the dimension of
  *head_size for every head
@@ -170,6 +192,32 @@ void mul_attenion_weights_and_value_of_head(
   }
 }
 
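+/*
+ * Grouped variant for MQA/GQA: apply the weighted-value accumulation of one
+ * value head to every query head in the group, stepping the attention weights
+ * by attn_w_stride and the output by attn_out_strideH, and marking each head
+ * as initialized through flag_access once it has been written.
+ */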
+template <typename T, typename T1>
+void mul_attenion_weights_and_value_of_head(
+    float* attn_w,
+    int attn_w_stride,
+    const T* v_ptr_start,
+    T1* attn_out_start,
+    int attn_out_strideH,
+    int kv_head_group_size,
+    int64_t head_size,
+    bool store_value,
+    T* v_cache_start,
+    uint8_t* flag_access) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    mul_attenion_weights_and_value_of_head<T, T1>(
+        attn_w[i * attn_w_stride],
+        v_ptr_start,
+        attn_out_start + i * attn_out_strideH,
+        head_size,
+        store_value,
+        v_cache_start,
+        flag_access[i]);
+    if (flag_access[i] == 0)
+      flag_access[i] = 1;
+  }
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 template <>
 void mul_attenion_weights_and_value_of_head(
@@ -594,17 +642,21 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
     for (auto block_id = 0; block_id < kv_block_count; block_id++) {
       for (auto bi = 0; bi < bs; bi++) {
-        for (auto hi = 0; hi < head_num; hi++) {
+        for (auto head_group_start = 0; head_group_start < head_num;
+             head_group_start += group_size) {
           auto k_start = block_id * kv_block_size;
           auto block_size = std::min(kv_block_size, seq_len - k_start);
           auto query_ti = 0;
           for (auto ti = k_start; ti < k_start + block_size; ti++) {
-            auto kv_hi = hi / group_size; // mapping the query head to
-                                          // key/value head to support GQA/MQA
+            auto kv_hi = head_group_start /
+                group_size; // mapping the query head to
+                            // key/value head to support GQA/MQA
             auto q_ptr_start = q_ptr +
                 (bi * cur_len + query_ti) * head_num * head_size +
-                hi * head_size;
-            auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+                head_group_start * head_size;
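+            // attn_w_stride2 is the stride between two query heads' rows in
+            // the attention-weight buffer; the grouped reduce_head uses it to
+            // write one dot-product result per query head in the group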
+            auto attn_w_stride2 = cur_len * seq_len;
+            auto attn_w_stride =
+                (bi * head_num + head_group_start) * attn_w_stride2;
             auto attn_w_pos =
                 attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
             attn_w_pos[0] = 0.0f;
@@ -632,8 +684,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                   kv_hi * head_size;
               reduce_head<QT>(
                   q_ptr_start,
+                  group_size,
                   k_ptr_start,
                   attn_w_pos,
+                  attn_w_stride2,
                   head_size,
                   true,
                   kc_head_start);
@@ -644,8 +698,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                   kv_hi * head_size;
               reduce_head<QT>(
                   q_ptr_start,
+                  group_size,
                   k_ptr_start,
                   attn_w_pos,
+                  attn_w_stride2,
                   head_size,
                   false,
                   nullptr);
@@ -662,8 +718,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                   k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
               reduce_head<QT>(
                   q_ptr_start,
+                  group_size,
                   kc_head_start,
                   attn_w_pos,
+                  attn_w_stride2,
                   head_size,
                   false,
                   nullptr);
@@ -737,6 +795,7 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto private_attn_out_flag =
       at::zeros({thread_numbers, bs, head_num}, at::kByte);
   auto flag_access = private_attn_out_flag.accessor<uint8_t, 3>();
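+  // raw pointer view of the [thread, bs, head_num] flag tensor so the grouped
+  // value reduction can address a whole group of heads through one pointer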
+  uint8_t* flag_access_ptr = flag_access.data();
   auto private_attn_out_ptr = private_attn_outs.data_ptr<float>();
   // private_attn_outs.numel());
   auto attn_outs_stride_priv = bs * head_num * cur_len * head_size;
@@ -747,7 +806,7 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
     for (auto block_id = 0; block_id < kv_block_count; block_id++) {
       for (auto bi = 0; bi < bs; bi++) {
-        for (auto hi = 0; hi < head_num; hi++) {
+        for (auto hi = 0; hi < head_num; hi += group_size) {
           auto thread_id = 0;
           if (kv_block_size < seq_len)
             thread_id = omp_get_thread_num();
@@ -757,15 +816,19 @@ scale_dot_product_for_indirect_access_kv_cache(
           for (auto vi = v_start; vi < v_start + block_size; vi++) {
             auto kv_hi = hi / group_size; // mapping the query head to
                                           // key/value head to support GQA/MQA
-            auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+            auto attn_w_stride2 = cur_len * seq_len;
+            auto attn_w_stride = (bi * head_num + hi) * attn_w_stride2;
             auto attn_w_query_start =
-                attn_w_ptr + attn_w_stride + query_ti * seq_len;
+                attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
             // calculate weighted value and store the result to attn_outs[bs,
             // head_num, cur_len, head_size]
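+            // attn_out_head_stride2 is the per-head stride in the private
+            // output buffer; the grouped value reduction uses it to step
+            // between the output rows of the heads in one group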
+            auto attn_out_head_stride2 = cur_len * head_size;
             auto attn_out_head_stride = thread_id * attn_outs_stride_priv +
-                (bi * head_num + hi) * cur_len * head_size;
+                (bi * head_num + hi) * attn_out_head_stride2;
             auto attn_out_start = private_attn_out_ptr + attn_out_head_stride +
                 query_ti * head_size;
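+            // flag_access_start points at this (thread, batch, head-group)
+            // slice of the flag tensor; entry i records whether query head
+            // hi + i has initialized its accumulator yet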
+            auto flag_access_start = flag_access_ptr +
+                head_num * bs * thread_id + head_num * bi + hi;
 
             auto vc_token_start = vi * kc_token_stride;
             auto beam = need_update_beam_idx ? new_beam_idx[bi][vi] : 0;
@@ -787,13 +850,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                   (bi * cur_len + vi - offset) * kv_head * head_size +
                   kv_hi * head_size;
               mul_attenion_weights_and_value_of_head<VT, float>(
-                  attn_w_query_start[vi],
+                  attn_w_query_start,
+                  attn_w_stride2,
                   v_ptr_start,
                   attn_out_start,
+                  attn_out_head_stride2,
+                  group_size,
                   head_size,
                   true,
                   v_cache_head_start,
-                  flag_access[thread_id][bi][hi]);
+                  flag_access_start);
             } else if (vi < query_ti + offset) { // calculate attention
                                                  // values for the past
                                                  // token
@@ -802,13 +868,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                   (bi * cur_len + vi - offset) * kv_head * head_size +
                   kv_hi * head_size;
               mul_attenion_weights_and_value_of_head<VT, float>(
-                  attn_w_query_start[vi],
+                  attn_w_query_start,
+                  attn_w_stride2,
                   v_ptr_start,
                   attn_out_start,
+                  attn_out_head_stride2,
+                  group_size,
                   head_size,
                   false,
                   nullptr,
-                  flag_access[thread_id][bi][hi]);
+                  flag_access_start);
             } else {
               auto vc_t_beam_start =
                   vc_token_start + beam * kv_head * head_size;
@@ -822,13 +891,16 @@ scale_dot_product_for_indirect_access_kv_cache(
               auto v_cache_head_start =
                   v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
               mul_attenion_weights_and_value_of_head<VT, float>(
-                  attn_w_query_start[vi],
+                  attn_w_query_start,
+                  attn_w_stride2,
                   v_cache_head_start,
                   attn_out_start,
+                  attn_out_head_stride2,
+                  group_size,
                   head_size,
                   false,
                   nullptr,
-                  flag_access[thread_id][bi][hi]);
+                  flag_access_start);
             }
           }
           if (flag_access[thread_id][bi][hi] == 0)