
@JohannesGaessler
Collaborator

This PR aims to implement CUDA kernels for matrix-vector multiplication that utilize dot products with quantized data instead of dequantizing the data on the fly. So far this is only implemented for q4_0. In order to get good performance, integer intrinsics are used. Unfortunately these have very poor performance on Pascal cards, so the current implementation with dequantization should be kept. For my RTX 3090 I found:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 | 7b q4_0 | tg128 | 91.06 | 101.39 | 1.11 |
| RTX 3090 | 13b q4_0 | tg128 | 51.88 | 57.95 | 1.12 |
| RTX 3090 | 33b q4_0 | tg128 | 22.83 | 25.71 | 1.13 |

For master I used the option LLAMA_CUDA_DMMV_F16, which uses f16 intrinsics for the calculation. Since this option is also only beneficial on relatively new cards and is seemingly inferior to integer intrinsics, I would suggest removing the f16 option in favor of this implementation.
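
For illustration, a minimal sketch of the integer-intrinsic dot product described above (not the exact kernel from this PR; the function name and per-call work split are hypothetical, and ggml's q4_0 layout is assumed: 16 quant bytes per 32-weight block, low nibbles holding weights 0-15 and high nibbles holding weights 16-31):

```cuda
// One q4_0 block (scale d4, 32 4-bit weights stored as 16 bytes) dotted against
// 32 already-quantized q8 activations (scale d8) using integer intrinsics.
__device__ float vec_dot_q4_0_q8_sketch(const int * vq4, const int * vq8,
                                        const float d4, const float d8) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {                // 4 ints = 16 bytes = 32 nibbles
        const int vi = vq4[i];
        const int lo = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); // weights 4*i .. 4*i+3
        const int hi = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808); // weights 16+4*i .. 16+4*i+3
        sumi = __dp4a(lo, vq8[i],     sumi);     // 4 signed byte products per call
        sumi = __dp4a(hi, vq8[i + 4], sumi);
    }
    return d4 * d8 * sumi;                       // rescale the integer sum once per block
}
```

The inner sum stays entirely in integer registers and the two float scales are applied only once per block, which is the gist of the speedup over dequantizing on the fly.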

@JohannesGaessler
Collaborator Author

I forgot: because I'm changing the way quantization is used in this PR, I would like to prioritize it over #2043 and then think again about how to approach the dequantization for that PR.

@slaren
Member

slaren commented Jul 1, 2023

This is probably not going to make much of a difference in practice, but the __syncthreads or __syncwarp before the warp shuffles shouldn't be necessary, since these already imply a sync (at least since they gained the _sync suffix; it wasn't always the case). See this for more details: https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/

> All the participating threads must be synchronized for the collective operation to work correctly. Therefore, these primitives first synchronize the threads if they are not already synchronized.

You may also gain some additional performance if instead of quantizing the vector to DRAM, you do it to shared memory at the beginning of vec_dot_q. It should be small enough to fit, in most cases at least.
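
As a concrete illustration of the first point, a minimal warp-sum helper (illustrative, not code from this PR): each *_sync shuffle synchronizes the lanes named in its mask, so no explicit __syncwarp() is needed beforehand.

```cuda
// Sum a value across the 32 lanes of a warp using only synchronized shuffles.
__device__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, 32);  // already implies the needed sync
    }
    return x;  // every lane ends up holding the full warp sum
}
```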

@JohannesGaessler
Collaborator Author

> You may also gain some additional performance if instead of quantizing the vector to DRAM, you do it to shared memory at the beginning of vec_dot_q. It should be small enough to fit, in most cases at least.

The problem is that the vector is loaded thousands of times by different blocks. So I think that quantizing it once and writing the quantized version to DRAM is faster than quantizing it thousands of times.

@slaren
Member

slaren commented Jul 1, 2023

You could use one block only and a lot more threads, and compute each row in a different warp. That worked for me in some tests I have been doing with the attention, but these matrices/vectors are very small, and it may not work so well for other (larger) matrix-vector multiplications.

@JohannesGaessler
Collaborator Author

If I remember correctly, the maximum block size is 1024 threads / 32 warps. So for 7b, where the smallest matrix has 4096 rows, that would still mean quantizing the vector 128 times (4096 / 32). Currently the quantization to q8_0 takes up 2.7% of the runtime, so I don't think doing it several times is going to be viable.

@slaren
Member

slaren commented Jul 1, 2023

If you adjust the block and grid size so that all blocks can be executed simultaneously, it may not matter that you have to quantize in each block, since it will be done simultaneously anyway. That may not work if the number of blocks is higher than the capacity of the GPU, but in that case you can still compute multiple rows in each warp and adjust the number of blocks accordingly. The 3090 can execute 1536 threads per SM, so a block size of 768 to fit two blocks in each SM may work best.
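
For reference, a small host-side sketch of the occupancy arithmetic behind this suggestion (it deliberately ignores register and shared-memory limits, which can also cap residency):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int block_size = 768;  // 24 warps; two such blocks fill the 1536-thread SM limit on CC 8.6
    const int resident_blocks = prop.maxThreadsPerMultiProcessor / block_size   // 2 on a 3090
                              * prop.multiProcessorCount;                       // * 82 SMs = 164
    printf("up to %d blocks of %d threads resident at once\n", resident_blocks, block_size);
    return 0;
}
```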

@Midaychi

Midaychi commented Jul 2, 2023

> Unfortunately these have very poor performance on Pascal cards, so the current implementation with dequantization should be kept.

If you use fp32-based operations on Pascal cards instead of fp16, they should have much better performance.

@JohannesGaessler
Collaborator Author

I'm not using any f16 intrinsics. The option for that is already on master. I'm using __vsub4 and __dp4a to do byte-wise subtractions and dot products on integers.
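
A worked example of what those two intrinsics compute on four packed bytes (the values are chosen purely for illustration):

```cuda
// __vsub4(a, b): byte-wise a - b; __dp4a(a, b, c): c + sum of the four signed byte products.
__global__ void dp4a_demo(int * out) {
    const int nibbles  = 0x0F010008;                    // bytes, low to high: 8, 0, 1, 15
    const int centered = __vsub4(nibbles, 0x08080808);  // byte-wise minus 8: 0, -8, -7, 7
    const int q8vals   = 0x01FF02FD;                    // signed bytes: -3, 2, -1, 1
    *out = __dp4a(centered, q8vals, 0);                 // 0*(-3) + (-8)*2 + (-7)*(-1) + 7*1 = -2
}
```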

@JohannesGaessler
Collaborator Author

I have implemented a kernel for q4_1 and to my surprise I've found that the performance is ~10% better than for q4_0. The reason seems to be that, due to q4_1 having a block size of 20 bytes vs. the 18 bytes of q4_0, it is possible to directly cast the pointer for the quants to int instead of having to resort to memcpy. Since I'm currently still using memcpy for q8_0, this implies that performance could be significantly improved by padding or reordering the q8_0 vector; I'll investigate.

More generally this may also mean that reordering the weights in some way may be of benefit after all.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 3, 2023

I pushed a version in which the vector is quantized to q8_1 (36-byte blocks) instead of q8_0 (34-byte blocks). This allows directly casting the int8 quant pointers to int32 pointers, which is significantly faster. With this I get 123 t/s for q4_0 using an RTX 3090. Reordering the data so that the scales and quants are in two separate blocks seems to have slightly worse performance, presumably due to cache locality.
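
A sketch of the alignment argument (the helper names are hypothetical; the block sizes follow ggml: 34 bytes for q8_0, 36 bytes for q8_1):

```cuda
#include <cstdint>
#include <cstring>

// 34-byte blocks: the quant bytes of consecutive blocks are only guaranteed 2-byte
// alignment, so four consecutive int8 values have to be gathered with memcpy.
__device__ __forceinline__ int load_q8_unaligned(const int8_t * qs, const int i) {
    int v;
    memcpy(&v, qs + 4*i, sizeof(int));   // compiles to several narrow loads
    return v;
}

// 36-byte blocks: the quant bytes stay 4-byte aligned across blocks,
// so the pointer can be reinterpreted and read with a single 32-bit load.
__device__ __forceinline__ int load_q8_aligned(const int8_t * qs, const int i) {
    return ((const int *) qs)[i];
}
```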

@casper-hansen

> I have implemented a kernel for q4_1 and to my surprise I've found that the performance is ~10% better than for q4_0. The reason seems to be that, due to q4_1 having a block size of 20 bytes vs. the 18 bytes of q4_0, it is possible to directly cast the pointer for the quants to int instead of having to resort to memcpy. Since I'm currently still using memcpy for q8_0, this implies that performance could be significantly improved by padding or reordering the q8_0 vector; I'll investigate.
>
> More generally this may also mean that reordering the weights in some way may be of benefit after all.

GPTQ implements a reordering approach based on quantization error: weights with the smallest error come first and weights with the largest error last.

Not sure if it's possible to achieve in llama.cpp; a side effect in GPTQ seemed to be performance issues.

@JohannesGaessler
Collaborator Author

I don't mean changing the order of the weights itself; I mean changing the way the data is laid out for better memory alignment.

@JohannesGaessler
Collaborator Author

I pushed implementations for q5_0, q5_1, and q8_0. I think I've picked the low-hanging fruit in terms of performance, so I'll focus on making the new features usable now. Since the integer intrinsics seem to rely on hardware support, I think I'll enable them based on compute capability. Ideally I can just set two compute capabilities in cmake and it will automatically use the highest one that a particular GPU supports.
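
One way such gating could look at compile time (a sketch; the macro name and the fallback are illustrative, not necessarily what the PR does; __dp4a exists in hardware from compute capability 6.1 onwards):

```cuda
#include <cstdint>

#define MIN_CC_DP4A 610  // lowest compute capability with a hardware dp4a instruction

static __device__ __forceinline__ int dot_bytes(const int a, const int b, int acc) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= MIN_CC_DP4A
    return __dp4a(a, b, acc);                            // single hardware instruction
#else
    // Scalar fallback for older architectures (or when the fast path is disabled).
#pragma unroll
    for (int k = 0; k < 4; ++k) {
        acc += (int)(int8_t)(a >> 8*k) * (int)(int8_t)(b >> 8*k);
    }
    return acc;
#endif
}
```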

@JohannesGaessler
Collaborator Author

@slaren do you think we should keep the dequantize_mul_mat_vec implementations using f16 intrinsics? They were slightly faster on recent NVIDIA cards but the integer intrinsics seem to be superior for those cases.

@slaren
Member

slaren commented Jul 4, 2023

I think that can still be useful for f16 models, so I would say keep it.

@JohannesGaessler marked this pull request as ready for review July 4, 2023 16:20
@JohannesGaessler
Collaborator Author

Alright, I now consider this ready to be merged. By default the new kernels are used (if the compute capability is high enough); the old DMMV kernels can still be used by setting LLAMA_CUDA_FORCE_DMMV. These are the final performance numbers on my system:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 | 7b q4_0 | tg128 | 90.40 | 121.52 | 1.34 |
| RTX 3090 | 13b q4_0 | tg128 | 51.32 | 69.23 | 1.35 |
| RTX 3090 | 33b q4_0 | tg128 | 22.65 | 31.91 | 1.41 |
| RTX 3090 | 7b q4_1 | tg128 | 84.30 | 115.00 | 1.36 |
| RTX 3090 | 7b q5_0 | tg128 | 60.75 | 103.35 | 1.70 |
| RTX 3090 | 7b q5_1 | tg128 | 60.69 | 99.08 | 1.63 |
| RTX 3090 | 7b q8_0 | tg128 | 72.89 | 77.55 | 1.06 |

Review comment on CMakeLists.txt (outdated)
Member


Since the lowest compute capability for the integer intrinsics is 70 in practice, I think this could be changed too, if only for clarity.

Collaborator Author


For a single GPU I would agree, but for multi-GPU setups that would be an issue. If you were to combine e.g. a Pascal and an Ampere card, you would want to use the integer intrinsics on the 8.6 Ampere card (but not on the 6.1 Pascal card). The decision of which implementation to use can be made at runtime by checking the compute capability per card, but only if the integer intrinsics are available at compile time.
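
A sketch of that per-device runtime decision (the helper name and the force flag are illustrative, not the PR's actual symbols):

```cuda
#include <cuda_runtime.h>

static bool device_supports_dp4a(int device) {
    int major = 0, minor = 0;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    return 10*major + minor >= 61;   // __dp4a is available in hardware from CC 6.1
}

// At launch time each GPU then picks its own path, e.g.:
//   const bool use_mul_mat_vec_q = device_supports_dp4a(dev) && !force_dmmv;
```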

@JohannesGaessler merged commit 924dd22 into ggml-org:master Jul 5, 2023
Member

@ggerganov left a comment


Nice speed-up! 🦙

My guess is that a similar approach for qmat x qmat should result in better performance than the existing mat x mat using cuBLAS.

@mirek190

mirek190 commented Jul 5, 2023

Is it possible to improve the q_K models like that, too?

@JohannesGaessler
Collaborator Author

It is very likely possible to apply the same techniques to q_K models. The reason I didn't do it is merely that the CUDA implementation for those was done very differently compared to the older quantization methods, which use a template. So I would rather work out all of the details on the older quantization methods before I touch half a dozen different k-quant implementations.

@mirek190

mirek190 commented Jul 5, 2023

I am asking because q4_K_M has very similar perplexity to q5_1 ... BUT a 33B model (63 layers) in q5_1 cannot be put entirely on a consumer GPU (RTX 3090 / 4090 with 24 GB); on the other hand q4_K_M fits perfectly, and there I get 18.5 t/s ... thinking I COULD get something close to 30 t/s with 33B and q4_K_M .... just OMG

@JohannesGaessler
Collaborator Author

Sorry, but you'll just need to be patient.

@LostRuins
Collaborator

Ever since this was merged, I am getting rubbish outputs when using CUDA (ref #2136).

The outputs are normal if GGML_CUDA_FORCE_DMMV is set to true, or if 0 layers are offloaded. Otherwise, the output ranges from a mix of garbled tokens to just a single repeated token.
