
ikawrakow (Contributor) commented Aug 22, 2023

TL;DR: A more accurate implementation of the perplexity calculation, triggered via ./perplexity ... --ppl-stride N.

It has been a long-standing mystery why PPL values computed with llama.cpp are higher than those coming from the Python universe. For instance, for LLaMA-v1-7B, the llama.cpp perplexity for the fp16 model is ~5.91, while the value frequently quoted around the Internet is ~5.68. One prominent hypothesis has been that the difference may be due to differences in the tokenization. But even after merging the GGUF changes and the improvements in the tokenizer, llama.cpp perplexities are largely unchanged.

I think this PR resolves the mystery: it is simply a matter of the perplexity computation not being done correctly in llama.cpp. The existing implementation evaluates chunks of n_ctx tokens (the context size) and then averages the log probabilities of the tokens in the range n_ctx/2 ... n_ctx of each chunk. This is roughly equivalent to a strided perplexity calculation with an effective context length of 3/4 * n_ctx and a stride of n_ctx/2, see e.g. this HF post. It is only roughly equivalent because this approach never scores half of the tokens at all.
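
For reference, the quantity being computed in all of these variants is the usual token-level perplexity (my notation, added here only for clarity):

$$\mathrm{PPL} = \exp\!\left(-\frac{1}{M}\sum_{i=1}^{M}\log p\left(x_i \mid x_{<i}\right)\right)$$

where M is the number of scored tokens and x_{<i} are the tokens preceding x_i within its evaluation window. The implementations differ only in which tokens get scored and how much preceding context each scored token sees.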

An exact, token-by-token perplexity evaluation would be too expensive to be of practical use. There are ~330,000 tokens in Wikitext, and each token takes ~8 ms on a high-end GPU, so one would need on the order of ~45 minutes for an evaluation (compared to the current ~3.5 minutes for 7B). But even if one wanted to implement it that way, this is currently not possible with ggml because one cannot modify the positional encoding (to be able to do that, one would need to change ggml to store embeddings in the KV cache before RoPE has been applied, and then apply RoPE as needed, given the current number of past tokens). Hence, this PR adds a strided implementation as discussed in the HF post. The new implementation is triggered using --ppl-stride N, where N is some integer; the smaller N is, the more accurate the result and the longer the computation takes (see below for actual usage examples). The implementation increases the context n_ctx supplied by the user by N/2, evaluates chunks of n_ctx + N/2 tokens, adds the log probabilities from the token range n_ctx - N/2 ... n_ctx + N/2, moves N tokens ahead, and repeats until all tokens have been processed.
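
To make the windowing concrete, here is a minimal, self-contained sketch of the strided loop described above (my own illustration, not the PR's code; eval_window is a hypothetical stand-in for a model forward pass that returns the summed negative log-likelihood of the scored positions of a window):

```cpp
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

// eval_window(first_tok, n_win, i0, i1): hypothetical model evaluation over a
// window of n_win tokens that returns the summed negative log-likelihood of
// the tokens at positions [i0, i1) within that window
using eval_fn = std::function<double(const int *, int, int, int)>;

static double strided_ppl(const std::vector<int> & tokens, int n_ctx, int stride, const eval_fn & eval_window) {
    const int n_win = n_ctx + stride/2;   // each evaluated chunk has n_ctx + N/2 tokens
    double nll = 0;                       // accumulated negative log-likelihood
    int    cnt = 0;                       // number of scored tokens
    for (size_t start = 0; start + n_win <= tokens.size(); start += stride) {
        // score only the last N tokens of the window, i.e. the positions
        // n_ctx - N/2 ... n_ctx + N/2, then move the window N tokens ahead
        nll += eval_window(tokens.data() + start, n_win, n_ctx - stride/2, n_win);
        cnt += stride;
    }
    return std::exp(nll/cnt);
}

int main() {
    std::vector<int> tokens(330000, 1);   // stand-in for the tokenized test set
    // dummy "model": every token gets probability 1/32000 (uniform over a 32k
    // vocabulary), so the computed perplexity must come out as exactly 32000
    eval_fn dummy = [](const int *, int, int i0, int i1) {
        return (i1 - i0)*std::log(32000.0);
    };
    printf("ppl = %.2f\n", strided_ppl(tokens, 512, 128, dummy));
    return 0;
}
```

With n_ctx = 512 and a stride of 128, each window evaluates 576 tokens in order to score only 128 of them, which is where the extra compute time in the table below comes from: most tokens are evaluated several times.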

For now I have kept the original implementation and added a new function perplexity_v2() instead of trying to fit the new logic into the existing implementation. This will make it easier to remove one of the two implementations in the future.

If ggml were changed to store un-RoPE'ed embeddings in the KV cache, a much more efficient implementation would be possible: one would evaluate the first n_ctx + N/2 tokens once, and then keep evaluating only N tokens at a time with n_past = n_ctx.
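
As a back-of-the-envelope illustration of the potential saving (my own sketch, not based on any existing ggml/llama.cpp API), the following counts how many token evaluations the current strided loop needs versus the n_past = n_ctx scheme sketched above:

```cpp
#include <cstdio>

int main() {
    const int n_ctx  = 512;                // context size supplied by the user
    const int stride = 128;                // N
    const int n_toks = 330000;             // roughly the size of the Wikitext test set
    const int n_win  = n_ctx + stride/2;   // tokens evaluated per window

    long long cur = 0, imp = 0;

    // current strided implementation: every step re-evaluates a full window of n_win tokens
    for (int start = 0; start + n_win <= n_toks; start += stride) {
        cur += n_win;
    }

    // with re-applicable RoPE: evaluate the first window once, then only N new
    // tokens per step while keeping n_past fixed at n_ctx
    imp += n_win;
    for (int start = n_win; start + stride <= n_toks; start += stride) {
        imp += stride;
    }

    printf("token evaluations: current = %lld, improved = %lld (%.1fx fewer)\n",
           cur, imp, (double)cur/imp);
    return 0;
}
```

Under these assumptions the improved scheme evaluates each token roughly once, which is consistent with the expectation (mentioned further down) that it would be almost as efficient as the original implementation.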

The following table gives compute times in seconds (using CUDA on a 4080 GPU) and perplexities obtained with LLaMA-v1-7B:

| Stride | Time (s) | PPL |
|--------|---------:|-------:|
| Master | 195      | 5.9066 |
| 512    | 347      | 5.7668 |
| 256    | 598      | 5.7291 |
| 128    | 1100     | 5.7226 |

We see that the strided approach gives a PPL that is much closer to the Python value (although it looks unlikely to converge to 5.68, so there must still be some differences between the implementations).

The following graph shows PPL as a function of the number of tokens evaluated so far for LLaMA-v2-7B. Based on this, a stride of 128 is likely to be good enough.
[Figure: strided_ppl — PPL as a function of the number of evaluated tokens for LLaMA-v2-7B]

ikawrakow requested a review from ggerganov on August 22, 2023 15:22
klosax (Contributor) commented Aug 22, 2023

Instead of overwriting the first token of each context window with BOS, properly inserting it will lower the PPL. I tested this during the BOS change in PR #1303.

ikawrakow (Contributor, Author) replied:

> Instead of overwriting the first token of each context window with BOS, properly inserting it will lower the PPL. I tested this during the BOS change in PR #1303.

This PR uses the exact same approach for BOS as the original perplexity calculation (which is the outcome of #1303). Now, if inserting a BOS token instead of replacing the first token of a context window with BOS made a significant difference, that would basically mean that the PPL for a context length of n_ctx + 1 is significantly different from the PPL for a context length of n_ctx. While this is certainly the case when n_ctx is very small, the difference is within the noise for n_ctx = 512 (I have verified this by running a few tests).

ikawrakow merged commit 62959e7 into master on Aug 23, 2023
ikawrakow deleted the ik/better_perplexity branch on August 23, 2023 09:56
Comment on lines +421 to +432
} else if (arg == "--ppl-stride") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.ppl_stride = std::stoi(argv[i]);
} else if (arg == "--ppl-output-type") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.ppl_output_type = std::stoi(argv[i]);
Collaborator:

These two parameters do not appear in the --help. I assume this is a simple oversight?

ikawrakow (Contributor, Author):

I did not add it to the help on purpose. As you can see from the table above, the current implementation is very inefficient. Hence, I decided that for now it is better to have this option kind of hidden from general usage and available only to those who pay attention to the commits. Later, when a better handling of RoPE becomes available (I discussed this with @ggerganov and this is on his radar), the implementation can be improved to be (almost) as efficient as the original llama.cpp implementation. My plan was to add the option to the help at that point.

Member:

Yes, as we discussed - we need more efficient KV cache reuse.
Will track this in the existing issue on the roadmap: #2060

Contributor:

Now that the KV cache changes have been merged, it should be possible to update the new perplexity calculation to perform about the same as the original. Right?

Member:

Yes
