@@ -525,6 +525,51 @@ inline std::array<__m256i, 2> load_sint4_as_int8(uint8_t* qB) {
525525 return {low, high};
526526}
527527
// Unpack 32 bytes of packed 4-bit NF4 codes at `qB` into two int8 vectors.
//
// Every input byte carries two 4-bit codes. The result is {low, high}:
//   - low[i]  is the table value for the low nibble  (bits 3:0) of byte i
//   - high[i] is the table value for the high nibble (bits 7:4) of byte i
//
// The 16-entry table spans [-127, 127]; its entries look like the NF4
// quantization levels multiplied by 127 and rounded. The NF4 call sites
// visible in this file compensate by multiplying the weight scales by
// 1/127.
//
// The load is unaligned, so `qB` only needs to point at 32 readable bytes.
//
// The lookup uses the in-lane byte shuffle (_mm256_shuffle_epi8) with the
// table replicated in both 128-bit lanes rather than the cross-lane
// _mm256_permutexvar_epi8. Indices are masked to 0..15 (bit 7 is always
// clear), so the results are identical, and the function needs nothing
// beyond AVX2 instead of additionally requiring AVX512_VBMI/AVX512VL.
inline std::array<__m256i, 2> load_nf4_as_int8(uint8_t* qB) {
  const __m256i packed =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
  const __m256i nibble_mask = _mm256_set1_epi8(0x0f);

  // Low nibbles: mask directly.
  const __m256i low_idx = _mm256_and_si256(packed, nibble_mask);
  // High nibbles: the 16-bit shift drags bits across byte boundaries, so the
  // mask afterwards is required to clear them.
  const __m256i high_idx =
      _mm256_and_si256(_mm256_srli_epi16(packed, 4), nibble_mask);

  // Table listed in index order (code 0 first), once per 128-bit lane.
  const __m256i lut = _mm256_setr_epi8(
      -127, -88, -67, -50, -36, -23, -12, 0,
      10, 20, 31, 43, 56, 71, 92, 127,
      -127, -88, -67, -50, -36, -23, -12, 0,
      10, 20, 31, 43, 56, 71, 92, 127);

  return {
      _mm256_shuffle_epi8(lut, low_idx), _mm256_shuffle_epi8(lut, high_idx)};
}
572+
528573#else
529574inline std::array<__m256i, 2 > load_zps_4vnni (int8_t * zps) {
530575 TLA_ASSERT (false , " not implemented" );
@@ -541,6 +586,11 @@ inline std::array<__m256i, 2> load_sint4_as_int8(uint8_t* qB) {
541586 return std::array<__m256i, 2 >();
542587}
543588
589+ inline std::array<__m256i, 2 > load_nf4_as_int8 (uint8_t * qB) {
590+ TLA_ASSERT (false , " not implemented" );
591+ return std::array<__m256i, 2 >();
592+ }
593+
544594#endif
545595
546596template <long N, bool sym_quant, typename T>
@@ -831,6 +881,10 @@ struct GemmMicroKernel<
831881 // Load scales and zps
832882 compile_time_for<COLS>::op ([&](auto i) {
833883 vscales[i] = _mm512_loadu_ps (scales + i * 16 );
884+ if constexpr (qw_type == WOQ_DTYPE_NF4) {
885+ const __m512 factor = _mm512_set1_ps (1 .0f / 127 .0f );
886+ vscales[i] = _mm512_mul_ps (vscales[i], factor);
887+ }
834888 // TODO(jgong5): should we use 512 or two 256 here?
835889 if constexpr (!sym_quant_w) {
836890 vzps[i] = combine_m256i (load_zps_4vnni (zps + i * 16 ));
@@ -859,8 +913,10 @@ struct GemmMicroKernel<
859913 if constexpr (!sym_quant_w) {
860914 vb[col] = combine_m256i (load_uint4_as_int8 (pqB[k / 4 ][col * 16 ]));
861915 vb[col] = _mm512_sub_epi8 (vb[col], vzps[col]);
862- } else {
916+ } else if constexpr (qw_type == WOQ_DTYPE_INT4) {
863917 vb[col] = combine_m256i (load_sint4_as_int8 (pqB[k / 4 ][col * 16 ]));
918+ } else {
919+ vb[col] = combine_m256i (load_nf4_as_int8 (pqB[k / 4 ][col * 16 ]));
864920 }
865921 if constexpr (is_asymmetric_quant_a (quant_a_mode)) {
866922 vcompensate[col] =
@@ -1290,12 +1346,12 @@ struct Dequantize<half, ldb, N_GROUP_SIZE, qw_type, sym_quant_w> {
12901346 }
12911347};
12921348
1293- template <long ldb, bool sym_quant_w>
1349+ template <long ldb, int qw_type, bool sym_quant_w>
12941350struct Dequantize <
12951351 int8_t ,
12961352 ldb,
12971353 /* N_GROUP_SIZE*/ 16 ,
1298- /* qw_type*/ WOQ_DTYPE_INT4 ,
1354+ qw_type,
12991355 sym_quant_w> {
13001356 template <int quant_a_mode>
13011357 static inline void call (
@@ -1330,10 +1386,14 @@ struct Dequantize<
13301386 auto [low, high] = load_uint4_as_int8 (pqB[k][n]);
13311387 vb_high = _mm256_sub_epi8 (high, vzps_high);
13321388 vb_low = _mm256_sub_epi8 (low, vzps_low);
1333- } else {
1389+ } else if constexpr (qw_type == WOQ_DTYPE_INT4) {
13341390 auto [low, high] = load_sint4_as_int8 (pqB[k][n]);
13351391 vb_low = low;
13361392 vb_high = high;
1393+ } else {
1394+ auto [low, high] = load_nf4_as_int8 (pqB[k][n]);
1395+ vb_low = low;
1396+ vb_high = high;
13371397 }
13381398 if constexpr (is_asymmetric_quant_a (quant_a_mode)) {
13391399 vcompensate[0 ] = _mm256_dpbusd_epi32 (vcompensate[0 ], ones, vb_low);
@@ -1585,6 +1645,7 @@ template <
15851645 long ldb,
15861646 bool transA,
15871647 bool ACC,
1648+ int qw_type,
15881649 int quant_a_mode,
15891650 int quant_w_mode,
15901651 long PREFETCH_K_DIST>
@@ -1598,7 +1659,7 @@ class DequantGemmTPP<
15981659 ldb,
15991660 transA,
16001661 ACC,
1601- /* qw_type*/ WOQ_DTYPE_INT4 ,
1662+ qw_type,
16021663 quant_a_mode,
16031664 quant_w_mode,
16041665 PREFETCH_K_DIST> {
@@ -1696,7 +1757,7 @@ class DequantGemmTPP<
16961757 ACC,
16971758 quant_a_mode,
16981759 PREFETCH_K_DIST>::
1699- template call<WOQ_DTYPE_INT4 , sym_quant_w>(
1760+ template call<qw_type , sym_quant_w>(
17001761 K,
17011762 qA[m],
17021763 lda,
@@ -1725,7 +1786,7 @@ class DequantGemmTPP<
17251786 ACC,
17261787 quant_a_mode,
17271788 PREFETCH_K_DIST>::
1728- template call<WOQ_DTYPE_INT4 , sym_quant_w>(
1789+ template call<qw_type , sym_quant_w>(
17291790 K,
17301791 qA[m],
17311792 lda,
@@ -1748,12 +1809,7 @@ class DequantGemmTPP<
17481809 int8_t B[K / 4 ][N][4 ];
17491810 int32_t qC[M][N];
17501811 int32_t compensation[N];
1751- Dequantize<
1752- int8_t ,
1753- ldb,
1754- N_GROUP_SIZE,
1755- /* qw_type*/ WOQ_DTYPE_INT4,
1756- sym_quant_w>::
1812+ Dequantize<int8_t , ldb, N_GROUP_SIZE, qw_type, sym_quant_w>::
17571813 template call<quant_a_mode>(qB, K, N, zps, B[0 ][0 ], compensation);
17581814 (*pgemm)((int8_t *)qA[0 ], B[0 ][0 ], qC[0 ], 1 , no_tile_cfg);
17591815 if constexpr (PREFETCH_K_DIST > 0 ) {
@@ -1782,11 +1838,14 @@ class DequantGemmTPP<
17821838 }
17831839 }
17841840 float c = 0 ;
1841+ auto scale = scales[n];
1842+ if constexpr (qw_type == WOQ_DTYPE_NF4) {
1843+ scale *= (1 .0f / 127 .0f );
1844+ }
17851845 if constexpr (is_asymmetric_quant_a (quant_a_mode)) {
1786- c = (qC[m][n] - compensation[n] * (*zp_a_m)) * (*scale_a_m) *
1787- scales[n];
1846+ c = (qC[m][n] - compensation[n] * (*zp_a_m)) * (*scale_a_m) * scale;
17881847 } else {
1789- c = (qC[m][n]) * (*scale_a_m) * scales[n] ;
1848+ c = (qC[m][n]) * (*scale_a_m) * scale ;
17901849 }
17911850 if constexpr (ACC) {
17921851 C[m * ldc + n] += c;
0 commit comments