@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
7777 }
7878 return 1 /iscale ;
7979 }
80+ bool return_early = false;
81+ if (rmse_type < 0 ) {
82+ rmse_type = - rmse_type ;
83+ return_early = true;
84+ }
8085 int weight_type = rmse_type %2 ;
8186 float sumlx = 0 ;
8287 float suml2 = 0 ;
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
8994 suml2 += w * l * l ;
9095 }
9196 float scale = sumlx /suml2 ;
97+ if (return_early ) return suml2 > 0 ? 0.5f * (scale + 1 /iscale ) : 1 /iscale ;
9298 float best = scale * sumlx ;
93- for (int itry = 0 ; itry < 3 ; ++ itry ) {
94- iscale = 1 /scale ;
95- float slx = 0 ;
96- float sl2 = 0 ;
97- bool changed = false;
98- for (int i = 0 ; i < n ; ++ i ) {
99- int l = nearest_int (iscale * x [i ]);
100- l = MAX (- nmax , MIN (nmax - 1 , l ));
101- if (l + nmax != L [i ]) { changed = true; }
102- float w = weight_type == 1 ? x [i ] * x [i ] : 1.f ;
103- slx += w * x [i ]* l ;
104- sl2 += w * l * l ;
105- }
106- if (!changed || sl2 == 0 || slx * slx <= best * sl2 ) { break ; }
107- for (int i = 0 ; i < n ; ++ i ) {
108- int l = nearest_int (iscale * x [i ]);
109- L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
110- }
111- sumlx = slx ; suml2 = sl2 ;
112- scale = sumlx /suml2 ;
113- best = scale * sumlx ;
114- }
115- for (int itry = 0 ; itry < 5 ; ++ itry ) {
116- int n_changed = 0 ;
117- for (int i = 0 ; i < n ; ++ i ) {
118- float w = weight_type == 1 ? x [i ]* x [i ] : 1 ;
119- int l = L [i ] - nmax ;
120- float slx = sumlx - w * x [i ]* l ;
121- if (slx > 0 ) {
122- float sl2 = suml2 - w * l * l ;
123- int new_l = nearest_int (x [i ] * sl2 / slx );
124- new_l = MAX (- nmax , MIN (nmax - 1 , new_l ));
125- if (new_l != l ) {
126- slx += w * x [i ]* new_l ;
127- sl2 += w * new_l * new_l ;
128- if (sl2 > 0 && slx * slx * suml2 > sumlx * sumlx * sl2 ) {
129- L [i ] = nmax + new_l ; sumlx = slx ; suml2 = sl2 ;
130- scale = sumlx / suml2 ; best = scale * sumlx ;
131- ++ n_changed ;
132- }
133- }
134- }
135- }
136- if (!n_changed ) { break ; }
137- }
138- if (rmse_type < 3 ) {
139- return scale ;
140- }
141- for (int is = -4 ; is <= 4 ; ++ is ) {
99+ for (int is = -9 ; is <= 9 ; ++ is ) {
142100 if (is == 0 ) {
143101 continue ;
144102 }
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
221179 return 1 /iscale ;
222180}
223181
224- static float make_qkx1_quants (int n , int nmax , const float * restrict x , uint8_t * restrict L , float * restrict the_min , int ntry ) {
182+ static float make_qkx1_quants (int n , int nmax , const float * restrict x , uint8_t * restrict L , float * restrict the_min ,
183+ int ntry , float alpha ) {
225184 float min = x [0 ];
226185 float max = x [0 ];
186+ float sum_x = 0 ;
187+ float sum_x2 = 0 ;
227188 for (int i = 1 ; i < n ; ++ i ) {
228189 if (x [i ] < min ) min = x [i ];
229190 if (x [i ] > max ) max = x [i ];
191+ sum_x += x [i ];
192+ sum_x2 += x [i ]* x [i ];
230193 }
231194 if (max == min ) {
232195 for (int i = 0 ; i < n ; ++ i ) L [i ] = 0 ;
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
254217 for (int i = 0 ; i < n ; ++ i ) {
255218 sum += x [i ] - scale * L [i ];
256219 }
257- min = sum /n ;
220+ min = alpha * min + ( 1 - alpha ) * sum /n ;
258221 if (min > 0 ) min = 0 ;
259222 iscale = 1 /scale ;
260223 if (!did_change ) break ;
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
263226 return scale ;
264227}
265228
229+ static float make_qkx2_quants (int n , int nmax , const float * restrict x , const float * restrict weights ,
230+ uint8_t * restrict L , float * restrict the_min , uint8_t * restrict Laux ,
231+ float rmin , float rdelta , int nstep , bool use_mad ) {
232+ float min = x [0 ];
233+ float max = x [0 ];
234+ float sum_w = weights [0 ];
235+ float sum_x = sum_w * x [0 ];
236+ for (int i = 1 ; i < n ; ++ i ) {
237+ if (x [i ] < min ) min = x [i ];
238+ if (x [i ] > max ) max = x [i ];
239+ float w = weights [i ];
240+ sum_w += w ;
241+ sum_x += w * x [i ];
242+ }
243+ if (min > 0 ) min = 0 ;
244+ if (max == min ) {
245+ for (int i = 0 ; i < n ; ++ i ) L [i ] = 0 ;
246+ * the_min = - min ;
247+ return 0.f ;
248+ }
249+ float iscale = nmax /(max - min );
250+ float scale = 1 /iscale ;
251+ float best_mad = 0 ;
252+ for (int i = 0 ; i < n ; ++ i ) {
253+ int l = nearest_int (iscale * (x [i ] - min ));
254+ L [i ] = MAX (0 , MIN (nmax , l ));
255+ float diff = scale * L [i ] + min - x [i ];
256+ diff = use_mad ? fabsf (diff ) : diff * diff ;
257+ float w = weights [i ];
258+ best_mad += w * diff ;
259+ }
260+ if (nstep < 1 ) {
261+ * the_min = - min ;
262+ return scale ;
263+ }
264+ for (int is = 0 ; is <= nstep ; ++ is ) {
265+ iscale = (rmin + rdelta * is + nmax )/(max - min );
266+ float sum_l = 0 , sum_l2 = 0 , sum_xl = 0 ;
267+ for (int i = 0 ; i < n ; ++ i ) {
268+ int l = nearest_int (iscale * (x [i ] - min ));
269+ l = MAX (0 , MIN (nmax , l ));
270+ Laux [i ] = l ;
271+ float w = weights [i ];
272+ sum_l += w * l ;
273+ sum_l2 += w * l * l ;
274+ sum_xl += w * l * x [i ];
275+ }
276+ float D = sum_w * sum_l2 - sum_l * sum_l ;
277+ if (D > 0 ) {
278+ float this_scale = (sum_w * sum_xl - sum_x * sum_l )/D ;
279+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl )/D ;
280+ if (this_min > 0 ) {
281+ this_min = 0 ;
282+ this_scale = sum_xl / sum_l2 ;
283+ }
284+ float mad = 0 ;
285+ for (int i = 0 ; i < n ; ++ i ) {
286+ float diff = this_scale * Laux [i ] + this_min - x [i ];
287+ diff = use_mad ? fabsf (diff ) : diff * diff ;
288+ float w = weights [i ];
289+ mad += w * diff ;
290+ }
291+ if (mad < best_mad ) {
292+ for (int i = 0 ; i < n ; ++ i ) {
293+ L [i ] = Laux [i ];
294+ }
295+ best_mad = mad ;
296+ scale = this_scale ;
297+ min = this_min ;
298+ }
299+ }
300+ }
301+ * the_min = - min ;
302+ return scale ;
303+ }
304+
266305#if QK_K == 256
267306static inline void get_scale_min_k4 (int j , const uint8_t * restrict q , uint8_t * restrict d , uint8_t * restrict m ) {
268307 if (j < 4 ) {
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
281320 const int nb = k / QK_K ;
282321
283322 uint8_t L [QK_K ];
323+ uint8_t Laux [16 ];
324+ float weights [16 ];
284325 float mins [QK_K /16 ];
285326 float scales [QK_K /16 ];
286327
@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
291332 float max_scale = 0 ; // as we are deducting the min, scales are always positive
292333 float max_min = 0 ;
293334 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
294- scales [j ] = make_qkx1_quants (16 , 3 , x + 16 * j , L + 16 * j , & mins [j ], 5 );
335+ for (int l = 0 ; l < 16 ; ++ l ) weights [l ] = fabsf (x [16 * j + l ]);
336+ scales [j ] = make_qkx2_quants (16 , 3 , x + 16 * j , weights , L + 16 * j , & mins [j ], Laux , -0.5f , 0.1f , 15 , true);
295337 float scale = scales [j ];
296338 if (scale > max_scale ) {
297339 max_scale = scale ;
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
637679 const int nb = k / QK_K ;
638680
639681 uint8_t L [QK_K ];
682+ uint8_t Laux [32 ];
683+ float weights [32 ];
640684 float mins [QK_K /32 ];
641685 float scales [QK_K /32 ];
642686
@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
645689 float max_scale = 0 ; // as we are deducting the min, scales are always positive
646690 float max_min = 0 ;
647691 for (int j = 0 ; j < QK_K /32 ; ++ j ) {
648- scales [j ] = make_qkx1_quants (32 , 15 , x + 32 * j , L + 32 * j , & mins [j ], 5 );
692+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
693+ float sum_x2 = 0 ;
694+ for (int l = 0 ; l < 32 ; ++ l ) sum_x2 += x [32 * j + l ] * x [32 * j + l ];
695+ float av_x = sqrtf (sum_x2 /32 );
696+ for (int l = 0 ; l < 32 ; ++ l ) weights [l ] = av_x + fabsf (x [32 * j + l ]);
697+ scales [j ] = make_qkx2_quants (32 , 15 , x + 32 * j , weights , L + 32 * j , & mins [j ], Laux , -1.f , 0.1f , 20 , false);
649698 float scale = scales [j ];
650699 if (scale > max_scale ) {
651700 max_scale = scale ;
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
798847 uint8_t L [QK_K ];
799848 float mins [QK_K /32 ];
800849 float scales [QK_K /32 ];
850+ float weights [32 ];
851+ uint8_t Laux [32 ];
801852#else
802853 int8_t L [QK_K ];
803854 float scales [QK_K /16 ];
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
810861 float max_scale = 0 ; // as we are deducting the min, scales are always positive
811862 float max_min = 0 ;
812863 for (int j = 0 ; j < QK_K /32 ; ++ j ) {
813- scales [j ] = make_qkx1_quants (32 , 31 , x + 32 * j , L + 32 * j , & mins [j ], 5 );
864+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
865+ float sum_x2 = 0 ;
866+ for (int l = 0 ; l < 32 ; ++ l ) sum_x2 += x [32 * j + l ] * x [32 * j + l ];
867+ float av_x = sqrtf (sum_x2 /32 );
868+ for (int l = 0 ; l < 32 ; ++ l ) weights [l ] = av_x + fabsf (x [32 * j + l ]);
869+ scales [j ] = make_qkx2_quants (32 , 31 , x + 32 * j , weights , L + 32 * j , & mins [j ], Laux , -0.5f , 0.1f , 15 , false);
814870 float scale = scales [j ];
815871 if (scale > max_scale ) {
816872 max_scale = scale ;
0 commit comments