|
4 | 4 | #include <torch/all.h> |
5 | 5 | #include <torch/csrc/autograd/function.h> |
6 | 6 | #include <limits> |
7 | | -#include "vec/vec.h" |
8 | 7 | #include "../../utils/isa_utils.h" |
| 8 | +#include "vec/vec.h" |
9 | 9 |
|
10 | 10 | namespace torch_ipex { |
11 | 11 | namespace cpu { |
@@ -1346,7 +1346,8 @@ first_token_masked_mha( |
1346 | 1346 | auto attn_outputs = at::Tensor(); |
1347 | 1347 | auto attn_weights = at::Tensor(); |
1348 | 1348 | if ((key.scalar_type() == at::kFloat || key.scalar_type() == at::kBFloat16 || |
1349 | | - (key.scalar_type() == at::kHalf && utils::isa_has_avx512_fp16_support())) && |
| 1349 | + (key.scalar_type() == at::kHalf && |
| 1350 | + utils::isa_has_avx512_fp16_support())) && |
1350 | 1351 | attention_mask.stride(-1) == 1) { |
1351 | 1352 | query = query.transpose(1, 2); |
1352 | 1353 | key = key.transpose(1, 2); |
@@ -1447,27 +1448,26 @@ masked_multihead_self_attention_kernel_impl( |
1447 | 1448 | query.size(0); // record the prompt bs info |
1448 | 1449 |
|
1449 | 1450 | } else if (offset > 0 && offset + cur_len > cache_size) { |
1450 | | - auto new_cache_size = cache_size * 2 + 2; |
| 1451 | + auto new_cache_size = cache_size * 2; |
1451 | 1452 | auto new_key_cache = at::empty( |
1452 | 1453 | {new_cache_size, beam_batch, key.size(2), key.size(3)}, key.options()); |
1453 | 1454 | auto new_value_cache = at::empty( |
1454 | 1455 | {new_cache_size, beam_batch, value.size(2), value.size(3)}, |
1455 | 1456 | value.options()); |
1456 | 1457 | auto new_beam_idx = |
1457 | | - at::zeros({new_cache_size, beam_batch}, beam_idx.options()); |
| 1458 | + at::zeros({new_cache_size + 2, beam_batch}, beam_idx.options()); |
1458 | 1459 | new_key_cache.slice(0, 0, cache_size).copy_(key_cache); |
1459 | 1460 | new_value_cache.slice(0, 0, cache_size).copy_(value_cache); |
1460 | | - new_beam_idx.slice(0, 0, cache_size).copy_(beam_idx); |
| 1461 | + new_beam_idx.slice(0, 0, cache_size + 2).copy_(beam_idx); |
1461 | 1462 | auto new_beam_idx_access = new_beam_idx.accessor<long, 2>(); |
1462 | 1463 | auto beam_idx_access = beam_idx.accessor<long, 2>(); |
1463 | 1464 | for (auto i = offset; i < new_cache_size; i++) { |
1464 | 1465 | for (auto j = 0; j < beam_batch; j++) { |
1465 | 1466 | new_beam_idx_access[i][j] = beam_idx_access[0][j]; |
1466 | 1467 | } |
1467 | 1468 | } |
1468 | | - new_beam_idx_access[new_cache_size - 2][0] = |
1469 | | - beam_idx_access[cache_size - 2][0]; |
1470 | | - new_beam_idx_access[new_cache_size - 1][0] = |
| 1469 | + new_beam_idx_access[new_cache_size][0] = beam_idx_access[cache_size - 2][0]; |
| 1470 | + new_beam_idx_access[new_cache_size + 1][0] = |
1471 | 1471 | beam_idx_access[cache_size - 1][0]; |
1472 | 1472 | key_cache = new_key_cache; |
1473 | 1473 | value_cache = new_value_cache; |
|
0 commit comments