Commit a8e39f6

return pre-sampling p
1 parent c9148ba commit a8e39f6

2 files changed: +137 additions, -72 deletions

examples/server/server.cpp

Lines changed: 47 additions & 64 deletions
@@ -1301,26 +1301,7 @@ struct server_context {
         }

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1416,6 +1397,33 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        // TODO: optimize this with min-p optimization
+        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+        // set probability for sampled token
+        for (size_t i = 0; i < n_vocab; i++) {
+            // set probability for sampled token
+            if (cur[i].id == result.tok) {
+                result.prob = cur[i].p;
+                break;
+            }
+        }
+
+        // set probability for top n_probs tokens
+        result.probs.reserve(n_probs);
+        for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+            result.probs.push_back({
+                cur[i].id,
+                llama_detokenize(ctx, {cur[i].id}, special),
+                cur[i].p
+            });
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -1507,19 +1515,7 @@ struct server_context {
         };

         if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                        slot.generated_token_probs.begin() + probs_pos,
-                        slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;
-
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, {tkn});
         }

         if (slot.oaicompat) {
@@ -1559,7 +1555,7 @@ struct server_context {
             {"timings", slot.get_formated_timings()}
         };

-        if (slot.sparams.n_probs > 0) {
+        if (!slot.params.stream && slot.sparams.n_probs > 0) {
            std::vector<completion_token_output> probs;
            if (!slot.params.stream && slot.stopped_word) {
                const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
@@ -2513,7 +2509,8 @@ struct server_context {
            }

            completion_token_output result;
-           const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+           const int tok_idx = slot.i_batch - i;
+           const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

            llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

@@ -2526,32 +2523,10 @@ struct server_context {

            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
            result.tok = id;
+           result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

-           const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-           if (n_probs > 0) {
-               const size_t n_valid = slot.ctx_sampling->n_valid;
-
-               // Make sure at least n_probs top tokens are at the front of the vector:
-               if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                   llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-               }
-
-               if (slot.sparams.temp == 0.0f) {
-                   // With greedy sampling the probabilities have possibly not been calculated.
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id,
-                           i == 0 ? 1.0f : 0.0f
-                       });
-                   }
-               } else {
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id,
-                           i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                       });
-                   }
-               }
+           if (slot.sparams.n_probs > 0) {
+               populate_token_probs(slot, result, params.special, tok_idx);
            }

            if (!process_token(result, slot)) {
@@ -2601,6 +2576,12 @@ static json format_final_response_oaicompat(const json& request, json result, co
                 {"message", json{{"content", content},
                                  {"role", "assistant"}}}} });

+   if (result.contains("completion_probabilities")) {
+       choices[0]["logprobs"] = json{
+           {"content", json_value(result, "completion_probabilities", json::array())},
+       };
+   }
+
    std::time_t t = std::time(0);

    json res = json{
@@ -2621,10 +2602,6 @@ static json format_final_response_oaicompat(const json& request, json result, co
        res["__verbose"] = result;
    }

-   if (result.contains("completion_probabilities")) {
-       res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-   }
-
    return res;
}

@@ -2712,6 +2689,12 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
        }
    }

+   if (result.contains("completion_probabilities")) {
+       choices[0]["logprobs"] = json{
+           {"content", json_value(result, "completion_probabilities", json::array())},
+       };
+   }
+
    json ret = json{
        {"choices", choices},
        {"created", t},

examples/server/utils.hpp

Lines changed: 90 additions & 8 deletions
@@ -111,6 +111,36 @@ static inline void server_log(const char * level, const char * function, int lin
    fflush(stdout);
}

+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string& text) {
+    size_t len = text.size();
+    if (len == 0) return 0;
+
+    // Check the last few bytes to see if a multi-byte character is cut off
+    for (size_t i = 1; i <= 4 && i <= len; ++i) {
+        unsigned char c = text[len - i];
+        // Check for start of a multi-byte sequence from the end
+        if ((c & 0xE0) == 0xC0) {
+            // 2-byte character start: 110xxxxx
+            // Needs at least 2 bytes
+            if (i < 2) return len - i;
+        } else if ((c & 0xF0) == 0xE0) {
+            // 3-byte character start: 1110xxxx
+            // Needs at least 3 bytes
+            if (i < 3) return len - i;
+        } else if ((c & 0xF8) == 0xF0) {
+            // 4-byte character start: 11110xxx
+            // Needs at least 4 bytes
+            if (i < 4) return len - i;
+        }
+    }
+
+    // If no cut-off multi-byte character is found, return full length
+    return len;
+}
+
//
// chat template utils
//
@@ -307,16 +337,31 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,

struct completion_token_output {
    llama_token tok;
+   float prob;
    std::string text_to_send;

-   struct token_prob {
+   struct prob_info {
        llama_token tok;
+       std::string txt;
        float prob;
    };

-   std::vector<token_prob> probs;
+   std::vector<prob_info> probs;
};

+static float logarithm(float x) {
+   // nlohmann::json converts -inf to null, so we need to prevent that
+   return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+}
+
+static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+   std::vector<unsigned char> bytes;
+   for (unsigned char c : str) {
+       bytes.push_back(c);
+   }
+   return bytes;
+}
+
// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
    json out = json::array();
@@ -325,17 +370,24 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
        json probs_for_token = json::array();

        for (const auto & p : prob.probs) {
-           const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+           std::string txt(p.txt);
+           txt.resize(validate_utf8(txt));
            probs_for_token.push_back(json {
-               {"tok_str", tok_str},
-               {"prob", p.prob},
+               {"id", p.tok},
+               {"token", txt},
+               {"bytes", str_to_bytes(p.txt)},
+               {"logprob", logarithm(p.prob)},
            });
        }

-       const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+       std::string txt(prob.text_to_send);
+       txt.resize(validate_utf8(txt));
        out.push_back(json {
-           {"content", tok_str},
-           {"probs", probs_for_token},
+           {"id", prob.tok},
+           {"token", txt},
+           {"bytes", str_to_bytes(prob.text_to_send)},
+           {"logprob", logarithm(prob.prob)},
+           {"top_logprobs", probs_for_token},
        });
    }

@@ -463,3 +515,33 @@ static json format_error_response(const std::string & message, const enum error_
        {"type", type_str},
    };
}
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // sort tokens by logits
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    return cur;
+}
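
Below is a minimal, self-contained sketch (not part of the commit) showing how the new `validate_utf8` helper treats text that ends in the middle of a multi-byte character. The helper body is copied verbatim from the hunk above; the surrounding test program and its file name are hypothetical.

    // validate_utf8_demo.cpp (hypothetical) -- build with: g++ -std=c++11 validate_utf8_demo.cpp
    #include <cassert>
    #include <cstddef>
    #include <string>

    // copied from examples/server/utils.hpp as added in this commit
    static size_t validate_utf8(const std::string& text) {
        size_t len = text.size();
        if (len == 0) return 0;

        // Check the last few bytes to see if a multi-byte character is cut off
        for (size_t i = 1; i <= 4 && i <= len; ++i) {
            unsigned char c = text[len - i];
            // Check for start of a multi-byte sequence from the end
            if ((c & 0xE0) == 0xC0) {
                // 2-byte character start: 110xxxxx
                // Needs at least 2 bytes
                if (i < 2) return len - i;
            } else if ((c & 0xF0) == 0xE0) {
                // 3-byte character start: 1110xxxx
                // Needs at least 3 bytes
                if (i < 3) return len - i;
            } else if ((c & 0xF8) == 0xF0) {
                // 4-byte character start: 11110xxx
                // Needs at least 4 bytes
                if (i < 4) return len - i;
            }
        }

        // If no cut-off multi-byte character is found, return full length
        return len;
    }

    int main() {
        // "€" (U+20AC) is encoded in UTF-8 as the three bytes 0xE2 0x82 0xAC
        const std::string full = "ab\xE2\x82\xAC";  // complete text, 5 bytes
        const std::string cut  = full.substr(0, 4); // last byte of the character is missing

        assert(validate_utf8(full) == full.size()); // the whole string is valid UTF-8
        assert(validate_utf8(cut)  == 2);           // only "ab" is safe to send for now

        // server.cpp uses the same comparison to hold back half-generated characters:
        //   bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
        return 0;
    }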
