@@ -283,76 +283,52 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
 #endif
 
 // out = val * a + b
-template <typename T1, typename T2>
+// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
+// take b as a scalar pointer.
+template <bool is_b_stride_zero, typename T1, typename T2>
 inline void _scale_attn_mask_fusion_kernel(
     T1* a,
     T2* b,
     const int& size,
     T1* out,
     T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T2>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 = tmp0 * vec_scale + tmp2;
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = (T1)b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = val * a + b
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    T1* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T1>::loadu(b + i);
-    auto tmp2 = tmp0 * vec_scale + tmp1;
-    _store(out + i, tmp2);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = b ? val * a : -inf
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    bool* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  auto neg_inf = -std::numeric_limits<T1>::infinity();
-  auto vec_neg_inf = at::vec::Vectorized<T1>(neg_inf);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<bool>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 =
-        at::vec::Vectorized<T1>::blendv(vec_neg_inf, tmp0 * vec_scale, tmp2);
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp1 ? tmp0 * val : neg_inf;
+  const auto vec_size1 = at::vec::Vectorized<T1>::size();
+  const auto vec_size2 = at::vec::Vectorized<T2>::size();
+  constexpr int64_t T1_n =
+      (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
+  constexpr int64_t T2_n = 1;
+  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
+  int64_t i = 0;
+  if (is_b_stride_zero) {
+    auto b_first_val = (T1)b[0];
+    auto b_first_vec = at::vec::VectorizedN<T2, T2_n>(b_first_val);
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = b_first_vec;
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = b_first_val;
+      out[i] = tmp0 * val + tmp1;
+    }
+  } else {
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = (T1)b[i];
+      out[i] = tmp0 * val + tmp1;
+    }
   }
 }
 
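For readability, here is a minimal scalar sketch (not part of the patch; names are illustrative) of what the consolidated kernel computes: `out[i] = a[i] * val + b[i]`, with `b` read only from `b[0]` when `is_b_stride_zero` is set for the broadcast case. The `VectorizedN` machinery above does the same thing a full vector at a time, converting `b` to `T1` on the fly.

```cpp
#include <cstdint>
#include <iostream>

// Scalar reference for the fused scale + mask-add, assuming the same semantics
// as the vectorized kernel above; is_b_stride_zero broadcasts b[0] to every lane.
template <bool is_b_stride_zero, typename T1, typename T2>
void scale_mask_fusion_ref(const T1* a, const T2* b, int64_t size, T1* out, T1 val) {
  for (int64_t i = 0; i < size; ++i) {
    T1 mask_val = static_cast<T1>(is_b_stride_zero ? b[0] : b[i]);
    out[i] = a[i] * val + mask_val;
  }
}

int main() {
  float qk[4] = {1.f, 2.f, 3.f, 4.f};
  float mask[4] = {0.f, -1.f, 0.f, -1.f};
  float out[4];
  scale_mask_fusion_ref</*is_b_stride_zero=*/false>(qk, mask, 4, out, 0.5f);
  for (float v : out) std::cout << v << ' ';  // 0.5 0 1.5 1
  std::cout << '\n';
  scale_mask_fusion_ref</*is_b_stride_zero=*/true>(qk, mask, 4, out, 0.5f);
  for (float v : out) std::cout << v << ' ';  // 0.5 1 1.5 2 (mask[0] broadcast)
  std::cout << '\n';
}
```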
@@ -425,6 +401,82 @@ inline void _mul_reduce_max_fusion_kernel(
       vec_tmp_max));
 }
 
+// This function is used to produce an attn_mask in a standard format
+inline std::optional<at::Tensor> convert_boolean_attn_mask(
+    const std::optional<at::Tensor>& attn_mask,
+    caffe2::TypeMeta dtype) {
+  // Pass through
+  if (!attn_mask.has_value()) {
+    return c10::nullopt;
+  }
+  // Convert boolean mask to additive mask
+  if (attn_mask->dtype() == at::kBool) {
+    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
+    new_attn_mask.masked_fill_(
+        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
+    return new_attn_mask;
+  }
+  // Otherwise, attn_mask represents an additive attention tensor
+  return attn_mask;
+}
+
+// Support mask shapes:
+// 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
+// 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
+inline bool check_attn_mask_shape(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  if (attn_mask.size(-2) != qSize && attn_mask.size(-2) != 1) {
+    return false;
+  }
+  if (attn_mask.size(-1) != kvSize && attn_mask.size(-1) != 1) {
+    return false;
+  }
+  if (attn_mask.dim() == 2) {
+    return true;
+  } else if (attn_mask.dim() == 4) {
+    if ((attn_mask.size(0) == 1 || attn_mask.size(0) == batchSize) &&
+        (attn_mask.size(1) == 1 || attn_mask.size(1) == num_head)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Reshape attention mask to 4d
+inline void reshape_attn_mask_to_4d(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  TORCH_CHECK(
+      check_attn_mask_shape(attn_mask, batchSize, num_head, qSize, kvSize),
+      "IPEX flash_attention: Please use the following attn mask shapes: ",
+      "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
+      "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
+  int64_t attn_mask_size_0 = 1;
+  int64_t attn_mask_size_1 = 1;
+  if (attn_mask.dim() == 4) {
+    if (attn_mask.size(0) == batchSize) {
+      attn_mask_size_0 = batchSize;
+    }
+    if (attn_mask.size(1) == num_head) {
+      attn_mask_size_1 = num_head;
+    }
+  }
+  attn_mask = attn_mask
+                  .view(
+                      {attn_mask_size_0,
+                       attn_mask_size_1,
+                       attn_mask.size(-2),
+                       attn_mask.size(-1)})
+                  .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
+}
+
 /*
  *Caculate the flash attention SDPA.
  *@template scalar_t: q/k/v data type
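As a side note, here is a hedged libtorch sketch (not taken from this patch) of what `convert_boolean_attn_mask` does: a boolean keep/drop mask becomes an additive mask in which kept positions contribute 0 and dropped positions contribute -inf, so they vanish after softmax.

```cpp
#include <torch/torch.h>
#include <limits>
#include <iostream>

int main() {
  // Boolean mask: true = attend, false = mask out.
  auto bool_mask = torch::tensor({{true, false}, {false, true}});
  // Same conversion as convert_boolean_attn_mask: zeros, then -inf where masked.
  auto additive = torch::zeros_like(bool_mask, torch::kFloat);
  additive.masked_fill_(bool_mask.logical_not(),
                        -std::numeric_limits<float>::infinity());
  std::cout << additive << std::endl;  // [[0, -inf], [-inf, 0]]
}
```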
@@ -480,6 +532,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -505,7 +563,13 @@ cpu_flash_attention(
           ? attention_mask.value().stride(1)
           : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+          ? attention_mask.value().stride(2)
+          : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+          ? attention_mask.value().stride(3)
+          : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -596,15 +660,24 @@ cpu_flash_attention(
           // And apply scaling factor
           if (attention_mask.has_value()) {
             for (int64_t row = 0; row < qBlockSize; ++row) {
-              // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-              // qk <- qk + attn_mask, else
-              _scale_attn_mask_fusion_kernel(
-                  qk_data + row * kvBlockSize,
-                  mask_data + i * mStrideB + j * mStrideH +
-                      (m + row) * mStrideM + n,
-                  kvBlockSize,
-                  qk_data + row * kvBlockSize,
-                  scaling_factor);
+              // qk <- qk * scaling_factor + attn_mask
+              if (mStrideN == 0) {
+                _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ true>(
+                    qk_data + row * kvBlockSize,
+                    mask_data + i * mStrideB + j * mStrideH +
+                        (m + row) * mStrideM,
+                    kvBlockSize,
+                    qk_data + row * kvBlockSize,
+                    scaling_factor);
+              } else {
+                _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ false>(
+                    qk_data + row * kvBlockSize,
+                    mask_data + i * mStrideB + j * mStrideH +
+                        (m + row) * mStrideM + n,
+                    kvBlockSize,
+                    qk_data + row * kvBlockSize,
+                    scaling_factor);
+              }
             }
           }
           // Update coefficients with Softmax
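One detail worth noting, shown with a hedged libtorch snippet (not from this patch; a kvSize of 16 is assumed): after `reshape_attn_mask_to_4d` expands a size-1 last dimension out to `kvSize`, that dimension keeps stride 0, which is what the `mStrideN == 0` check above detects before taking the scalar-broadcast (`is_b_stride_zero`) kernel path.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // A mask that is constant along the key/value dimension: {Batch, 1, Q_seq_len, 1}.
  auto mask = torch::randn({2, 1, 8, 1});
  // Mirrors the expand done by reshape_attn_mask_to_4d (assumed kvSize = 16).
  auto mask4d = mask.expand({2, 1, 8, 16});
  std::cout << mask4d.stride(3) << std::endl;  // 0 -> broadcast over kv, scalar-b path
  std::cout << mask4d.stride(2) << std::endl;  // 1 -> one mask value per query row
}
```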
@@ -737,6 +810,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -762,7 +841,13 @@ cpu_flash_attention(
           ? attention_mask.value().stride(1)
           : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+          ? attention_mask.value().stride(2)
+          : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+          ? attention_mask.value().stride(3)
+          : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -1241,15 +1326,24 @@ cpu_flash_attention(
           // And apply scaling factor
           if (attention_mask.has_value()) {
             for (int64_t row = 0; row < qBlockSize; ++row) {
-              // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-              // qk <- qk + attn_mask, else
-              _scale_attn_mask_fusion_kernel(
-                  qk_data + row * kvBlockSize,
-                  mask_data + i * mStrideB + j * mStrideH +
-                      (m + row) * mStrideM + n,
-                  kvBlockSize,
-                  qk_data + row * kvBlockSize,
-                  scaling_factor);
+              // qk <- qk * scaling_factor + attn_mask
+              if (mStrideN == 0) {
+                _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ true>(
+                    qk_data + row * kvBlockSize,
+                    mask_data + i * mStrideB + j * mStrideH +
+                        (m + row) * mStrideM,
+                    kvBlockSize,
+                    qk_data + row * kvBlockSize,
+                    scaling_factor);
+              } else {
+                _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ false>(
+                    qk_data + row * kvBlockSize,
+                    mask_data + i * mStrideB + j * mStrideH +
+                        (m + row) * mStrideM + n,
+                    kvBlockSize,
+                    qk_data + row * kvBlockSize,
+                    scaling_factor);
+              }
             }
           }
           // Update coefficients with Softmax
@@ -1558,6 +1652,8 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
           attention_mask.value().stride(-1) == 1),
       "IPEX flash_attention: Q/K/V/Mask should be continuous on the last dim");
 
+  std::optional<at::Tensor> attn_mask =
+      convert_boolean_attn_mask(attention_mask, query.dtype());
   at::Tensor output =
       at::empty({batchSize, qSize, num_head, headSize}, query.options());
   const auto accumulate_dtype = at::toOpMathType(dtype);
@@ -1572,7 +1668,7 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
       value,
       dropout_p,
       is_causal,
-      attention_mask,
+      attn_mask,
       scale);
 
   output = output.transpose(1, 2);