@@ -757,8 +757,8 @@ struct llama_server_context
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             // add the token to slot queue and cache
-            slot.addTokenString(result);
         }
+        slot.addTokenString(result);
         if (slot.multibyte_pending > 0)
         {
             slot.multibyte_pending -= token_str.size();
@@ -925,8 +925,8 @@ struct llama_server_context
         }

         // context shift takes effect only when there is a single slot
-        if (slots.size() == 1) {
-            llama_client_slot slot = slots[0];
+        if (params.n_parallel == 1) {
+            llama_client_slot & slot = slots[0];
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
             {
                 // Shift context
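The hunk above cuts off at the "Shift context" comment. As an editor's sketch (not part of the commit), the shift that this single-slot branch guards typically keeps the first n_keep tokens, discards a chunk of the remaining KV cache, and slides the tail back so generation can continue; it assumes the llama_kv_cache_seq_rm / llama_kv_cache_seq_shift API from the same era of llama.h, and the names n_left / n_discard are illustrative, not taken from this commit:

#include "llama.h"

// Hedged sketch: free KV-cache cells after the first n_keep tokens and shift
// the surviving tail left so the sequence fits inside the context window again.
static void shift_context_sketch(llama_context * ctx, llama_seq_id seq_id, int n_keep, int n_past) {
    const int n_left    = n_past - n_keep;  // tokens eligible for discarding
    const int n_discard = n_left / 2;       // drop half of them (common heuristic)

    // remove cells [n_keep, n_keep + n_discard) for this sequence
    llama_kv_cache_seq_rm(ctx, seq_id, n_keep, n_keep + n_discard);
    // shift the remaining cells back by n_discard positions
    llama_kv_cache_seq_shift(ctx, seq_id, n_keep + n_discard, n_past, -n_discard);
}

The real server code would also have to subtract n_discard from the slot's n_past and trim slot.cache_tokens accordingly; this fragment only shows the KV-cache side of the operation.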
@@ -1028,22 +1028,16 @@ struct llama_server_context

             slot.num_prompt_tokens = prompt_tokens.size();

-            slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;
-
-            slot.cache_tokens = prompt_tokens;
-
-            if (slot.n_past == slot.num_prompt_tokens) {
-                // we have to evaluate at least 1 token to generate logits.
-                printf("we have to evaluate at least 1 token to generate logits\n");
-                slot.n_past--;
-            }
-
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
-
-            if (!slot.params.cache_prompt) {
+            if (!slot.params.cache_prompt) {
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                slot.n_past = 0;
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                if (params.n_keep < 0 && params.n_parallel == 1)
+                {
+                    params.n_keep = (int)slot.num_prompt_tokens;
+                }
+                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
                 // if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx)
                 {
@@ -1059,14 +1053,26 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
+                    slot.num_prompt_tokens = prompt_tokens.size();
                 }
                 const size_t ps = slot.num_prompt_tokens;
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
                 std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }

             llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);

+            slot.cache_tokens = prompt_tokens;
+
+            if (slot.n_past == slot.num_prompt_tokens) {
+                // we have to evaluate at least 1 token to generate logits.
+                printf("we have to evaluate at least 1 token to generate logits\n");
+                slot.n_past--;
+            }
+
             LOG_VERBOSE("prompt ingested", {
                 {"n_past", slot.n_past},
                 {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
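For orientation (editor's note, not part of the diff): common_part() is assumed to return the length of the shared token prefix between the cached sequence and the new prompt, which is exactly the value assigned to slot.n_past above. A minimal, self-contained sketch of such a helper:

#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = int32_t;  // assumption: matches the llama.h token type

// Count how many leading tokens the cached sequence and the new prompt share;
// those tokens can stay in the KV cache and do not need to be re-evaluated.
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

Measuring the prefix after truncation (rather than before, as in the removed lines) keeps n_past consistent with the tokens that will actually be evaluated, and slot.cache_tokens is only overwritten once the reusable prefix has been recorded.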
@@ -1185,7 +1191,7 @@ struct llama_server_context
             }
         }

-        if (kv_cache_free < 0) {
+        if (kv_cache_free < 0 && params.n_parallel > 1) {
             LOG_TEE("\nError: kv cache is full, increase context size.");
             return false;
         }
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }

+static void slot_print_timings(struct llama_client_slot * slot) {
+    LOG_TEE("\n");
+    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
+    LOG_TEE("%s:        eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
+    LOG_TEE("%s:       total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
+}
+
 static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
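The two rates printed by slot_print_timings() above are just reciprocal views of the same measurement: milliseconds per token is t / n, and tokens per second is 1e3 * n / t. A standalone sanity check with hypothetical numbers (350 ms spent on 70 prompt tokens, an assumption for illustration only):

#include <cstdio>

int main() {
    const double t_prompt_processing = 350.0; // ms, hypothetical
    const int    n_tokens            = 70;    // tokens, hypothetical

    const double ms_per_token      = t_prompt_processing / n_tokens;       // 5.00 ms per token
    const double tokens_per_second = 1e3 / t_prompt_processing * n_tokens; // 200.00 tokens per second

    printf("%.2f ms per token, %.2f tokens per second\n", ms_per_token, tokens_per_second);
    return 0;
}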
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
         {"penalize_nl", slot->sparams.penalize_nl},
         {"stop", slot->params.antiprompt},
         {"n_predict", slot->params.n_predict},
-        // {"n_keep", slot.params.n_keep},
+        {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", slot->params.stream},
         {"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", -1);
+    llama.params.n_keep = json_value(body, "n_keep", 0);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
             }

             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             slot->release();
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
             }

             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
         } else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +