@@ -1287,6 +1287,27 @@ static llama_vocab::id llama_sample_top_p_top_k(
 // quantization
 //
 
+#include "ggml_internal.h"
+
+struct error_stats {
+    size_t num_samples;
+    double total_error;
+    double max_error;
+};
+
+static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
+    for (int64_t i = 0; i < nelements; i++) {
+        double diff = input[i] - output[i];
+        stats.total_error += diff * diff;
+        stats.max_error = fmax(fabs(diff), stats.max_error);
+    }
+    stats.num_samples += nelements;
+}
+
+static void print_error_stats(const std::string & name, const error_stats & stats) {
+    printf("%-50s: mse %.8f, maxerr %.8f\n", name.c_str(), stats.total_error / (double) stats.num_samples, stats.max_error);
+}
+
 // TODO: reuse code from the llama_model_load() somehow
 static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
     ggml_type type = GGML_TYPE_Q4_1;
@@ -1312,10 +1333,17 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
         return false;
     }
 
-    auto fout = std::ofstream(fname_out, std::ios::binary);
-    if (!fout) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
-        return false;
+    bool stats = fname_out.empty();
+    error_stats total_error {};
+    std::vector<float> output_scratch;
+
+    std::ofstream fout;
+    if (!stats) {
+        fout.open(fname_out, std::ios::binary);
+        if (!fout) {
+            fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+            return false;
+        }
     }
 
     // verify magic
@@ -1549,6 +1577,15 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
                 printf("%5.3f ", hist_cur[i] / float(nelements));
             }
             printf("\n");
+
+            if (stats && !std::regex_match(name, std::regex("norm"))) {
+                quantize_fns_t qfns = ggml_internal_get_quantize_fn(type);
+#define QK 32
+                assert(nelements % QK == 0);
+                output_scratch.resize(nelements);
+                qfns.dequantize_row_q(work.data(), output_scratch.data(), nelements);
+                update_error_stats(nelements, data_f32.data(), output_scratch.data(), total_error);
+            }
         } else {
             printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
             fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
@@ -1578,6 +1615,11 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
     finp.close();
     fout.close();
 
+    if (stats) {
+        static const char * ggml_type_str[] = { "q4_0", "q4_1", };
+        print_error_stats(ggml_type_str[type], total_error);
+    }
+
     return true;
 }
 
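For reference, a minimal standalone sketch (not part of the patch) of how the new error_stats helpers behave: the struct and the two functions are copied verbatim from the first hunk so the snippet compiles on its own, and a coarse 1/16 rounding step stands in for the real ggml quantize/dequantize round trip.

// Standalone sketch: accumulate and print quantization error statistics.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct error_stats {
    size_t num_samples;
    double total_error;
    double max_error;
};

static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
    for (int64_t i = 0; i < nelements; i++) {
        double diff = input[i] - output[i];
        stats.total_error += diff * diff;           // accumulate squared error
        stats.max_error = fmax(fabs(diff), stats.max_error);
    }
    stats.num_samples += nelements;
}

static void print_error_stats(const std::string & name, const error_stats & stats) {
    printf("%-50s: mse %.8f, maxerr %.8f\n", name.c_str(), stats.total_error / (double) stats.num_samples, stats.max_error);
}

int main() {
    std::vector<float> input(1024), output(1024);
    for (size_t i = 0; i < input.size(); i++) {
        input[i]  = sinf(0.01f * (float) i);        // arbitrary test signal
        output[i] = roundf(input[i] * 16) / 16;     // toy "quantization": round to a 1/16 grid
    }
    error_stats stats {};
    update_error_stats((int64_t) input.size(), input.data(), output.data(), stats);
    print_error_stats("toy-roundtrip", stats);      // maxerr stays at or below 1/32 for this grid
    return 0;
}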