
Conversation

@lshzh-ww
Contributor

@lshzh-ww lshzh-ww commented Aug 14, 2023

This commit removes MetalPerformanceShaders and uses custom matrix-matrix multiplication kernels for all quantization types. It also adds grouped-query attention to support the llama2 70B model.
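For readers new to grouped-query attention, here is a minimal, hypothetical sketch (not code from this PR) of the head mapping it implies; the head counts are the published LLaMA-2 70B values:

```cpp
// Hypothetical illustration of grouped-query attention (GQA) head mapping.
// Several query heads share one key/value head, which is what lets the 70B
// model (64 query heads, 8 KV heads) keep its KV cache small.
#include <cstdio>

int main() {
    const int n_head    = 64;                // query heads (LLaMA-2 70B)
    const int n_head_kv = 8;                 // key/value heads (LLaMA-2 70B)
    const int gqa       = n_head / n_head_kv;

    for (int h = 0; h < n_head; h += gqa) {
        // query heads h .. h+gqa-1 all attend over the same KV head h/gqa
        printf("query heads %2d..%2d -> kv head %d\n", h, h + gqa - 1, h / gqa);
    }
    return 0;
}
```

In the kernels this mapping shows up as an extra integer divide when computing the source offset for the K/V tensors, which is what the performance note in the commit message further below is about.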


When multiplying a 4096x4096 matrix by a 4096xne1 matrix, our custom kernel achieves roughly 88% hardware ALU utilization. However, when profiling llama.cpp over a 256-token evaluation, the average utilization is only around 40%. A considerable amount of time is spent outside matrix multiplication, but that is beyond the scope of this PR.

M1 Max 32c GPU

| model | PR |
| --- | --- |
| llama 7B f16 | 316 tok/s |
| llama 7B q4_0 | 297 tok/s |

We encourage people to thoroughly test this branch.

@lshzh-ww lshzh-ww requested a review from ggerganov August 14, 2023 18:38
@lshzh-ww lshzh-ww force-pushed the metal-mat-mul branch 2 times, most recently from 19bbd3b to 6b22562 Compare August 14, 2023 18:46
@lshzh-ww
Contributor Author

Test command (256 tokens in total):


./main -m ~/model_file -c 512 -s 12 -p "I believe that I believe what we all want in life is to be happy. I truly do believe that. In my pursuit of happiness and fulfillment, I have found that traveling, meeting new people, and experiencing different cultures is at the top of my list. I believe that I believe what we all want in life is to be happy. I truly do believe that. In my pursuit of happiness and fulfillment, I have found that traveling, meeting new people, and experiencing different cultures is at the top of my list. I believe that I believe what we all want in life is to be happy. I truly do believe that. In my pursuit of happiness and fulfillment, I have found that traveling, meeting new people, and experiencing different cultures is at the top of my list. I believe that I believe what we all want in life is to be happy. I truly do believe that. In my pursuit of happiness and fulfillment, I have found that traveling, meeting new people, and experiencing different cultures is at the top of my list. I believe that I believe what we all want in life is to be happy. I truly do believe that. " --ignore-eos -ngl 1 -n 2 --no-mmap -t 8

@jhen0409
Collaborator

M2 10c GPU with Llama 2 7B q4_0: 100 tok/s (42 tok/s before)

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.
Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.
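To make the trade-off above concrete, here is a hedged, standalone sketch (hypothetical helper names in ggml's ne/nb naming style, not the actual Metal kernel) of computing the GQA broadcast offset with a 32-bit instead of a 64-bit divide:

```cpp
#include <cstdint>

// 64-bit version: both the divide and the multiply run in 64-bit integer
// math, which is extremely slow on the GPU.
uint64_t src_offset_64(uint64_t i02, uint64_t gqa, uint64_t nb02, uint64_t base) {
    return base + (i02 / gqa) * nb02;
}

// 32-bit version: the part of the offset that stays inside one matrix is
// computed entirely in 32-bit math and only added to the 64-bit base at the
// end. This is why a single matrix is limited to ~4 GB, as noted above.
uint64_t src_offset_32(uint32_t i02, uint32_t gqa, uint32_t nb02, uint64_t base) {
    const uint32_t off = (i02 / gqa) * nb02;   // cheap 32-bit divide + multiply
    return base + off;
}
```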
@ggerganov
Member

I get ~672 t/s for prompt processing (PP), but I think there is something wrong in the computation.
When I run the perplexity tool, the values are quite high:

$LLAMA_METAL=1 make -j && ./perplexity -m models/7B/ggml-model-q4_0.bin -f build/wikitext-2-raw/wiki.test.raw -ngl 1
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -pthread -DGGML_USE_K_QUANTS -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
I CXXFLAGS: -I. -I./examples -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -DGGML_USE_K_QUANTS -DGGML_USE_METAL
I LDFLAGS:   -framework Accelerate -framework Foundation -framework Metal -framework MetalKit
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

make: Nothing to be done for `default'.
main: build = 972 (bfa455d)
main: seed  = 1692082979
llama.cpp: loading model from models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_head_kv  = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: n_gqa      = 1
llama_model_load_internal: rnorm_eps  = 5.0e-06
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: freq_base  = 10000.0
llama_model_load_internal: freq_scale = 1
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: mem required  = 3949.96 MB (+  256.00 MB per state)
llama_new_context_with_model: kv self size  =  256.00 MB
ggml_metal_init: allocating
ggml_metal_init: loading '/Users/ggerganov/development/github/llama.cpp/ggml-metal.metal'
ggml_metal_init: loaded kernel_add                            0x125609700
ggml_metal_init: loaded kernel_add_row                        0x125609d20
ggml_metal_init: loaded kernel_mul                            0x12560a260
ggml_metal_init: loaded kernel_mul_row                        0x12560a8b0
ggml_metal_init: loaded kernel_scale                          0x12560adf0
ggml_metal_init: loaded kernel_silu                           0x12560b330
ggml_metal_init: loaded kernel_relu                           0x12560b870
ggml_metal_init: loaded kernel_gelu                           0x12560bdb0
ggml_metal_init: loaded kernel_soft_max                       0x12560c480
ggml_metal_init: loaded kernel_diag_mask_inf                  0x12560cb00
ggml_metal_init: loaded kernel_get_rows_f16                   0x12560d1d0
ggml_metal_init: loaded kernel_get_rows_q4_0                  0x12560da10
ggml_metal_init: loaded kernel_get_rows_q4_1                  0x12560e0e0
ggml_metal_init: loaded kernel_get_rows_q2_K                  0x12560e7b0
ggml_metal_init: loaded kernel_get_rows_q3_K                  0x12560ee80
ggml_metal_init: loaded kernel_get_rows_q4_K                  0x12560f550
ggml_metal_init: loaded kernel_get_rows_q5_K                  0x12560fc20
ggml_metal_init: loaded kernel_get_rows_q6_K                  0x1256102f0
ggml_metal_init: loaded kernel_rms_norm                       0x1256109d0
ggml_metal_init: loaded kernel_norm                           0x125611210
ggml_metal_init: loaded kernel_mul_mat_f16_f32                0x125611ae0
ggml_metal_init: loaded kernel_mul_mat_q4_0_f32               0x125612240
ggml_metal_init: loaded kernel_mul_mat_q4_1_f32               0x1256129a0
ggml_metal_init: loaded kernel_mul_mat_q2_K_f32               0x125613280
ggml_metal_init: loaded kernel_mul_mat_q3_K_f32               0x1256139e0
ggml_metal_init: loaded kernel_mul_mat_q4_K_f32               0x125614140
ggml_metal_init: loaded kernel_mul_mat_q5_K_f32               0x1256148a0
ggml_metal_init: loaded kernel_mul_mat_q6_K_f32               0x125615460
ggml_metal_init: loaded kernel_mul_mm_f16_f32                 0x125615c20
ggml_metal_init: loaded kernel_mul_mm_q4_0_f32                0x1256163e0
ggml_metal_init: loaded kernel_mul_mm_q4_1_f32                0x125616ba0
ggml_metal_init: loaded kernel_mul_mm_q2_K_f32                0x125617360
ggml_metal_init: loaded kernel_mul_mm_q3_K_f32                0x1256178a0
ggml_metal_init: loaded kernel_mul_mm_q4_K_f32                0x125618060
ggml_metal_init: loaded kernel_mul_mm_q5_K_f32                0x125618820
ggml_metal_init: loaded kernel_mul_mm_q6_K_f32                0x125618fe0
ggml_metal_init: loaded kernel_rope                           0x125619520
ggml_metal_init: loaded kernel_alibi_f32                      0x125619e00
ggml_metal_init: loaded kernel_cpy_f32_f16                    0x12561a6b0
ggml_metal_init: loaded kernel_cpy_f32_f32                    0x12561af60
ggml_metal_init: loaded kernel_cpy_f16_f16                    0x12561b810
ggml_metal_init: recommendedMaxWorkingSetSize = 147456.00 MB
ggml_metal_init: hasUnifiedMemory             = true
ggml_metal_init: maxTransferRate              = built-in GPU
llama_new_context_with_model: max tensor size =   102.54 MB
ggml_metal_add_buffer: allocated 'data            ' buffer, size =  3648.31 MB, ( 3648.75 / 147456.00)
ggml_metal_add_buffer: allocated 'eval            ' buffer, size =    10.17 MB, ( 3658.92 / 147456.00)
ggml_metal_add_buffer: allocated 'kv              ' buffer, size =   258.00 MB, ( 3916.92 / 147456.00)
ggml_metal_add_buffer: allocated 'scr0            ' buffer, size =   132.00 MB, ( 4048.92 / 147456.00)
ggml_metal_add_buffer: allocated 'scr1            ' buffer, size =   160.00 MB, ( 4208.92 / 147456.00)

system_info: n_threads = 16 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 0.95 seconds per pass - ETA 10 minutes
[1]2822.3244,[2]17718.6767,[3]38215.2720,[4]53667.8852,^C

In comparison, with -ngl 0 we get normal values:

[1]4.4122,[2]4.9228,[3]5.8191,...

You can get wikitext test data from here:

https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip

I mixed up ne02 and nb02 in the previous commit.
@lshzh-ww
Contributor Author

@ggerganov fixed.

I accidentally mixed up ne02 with nb02. The perplexity results look right now. Btw, could you also evaluate the llama2 70B model? That model is too large to run inference on my laptop.

@MichaelDays

MichaelDays commented Aug 15, 2023

Running the test command on llama-2-7B-chat q4_0; numbers are prompt eval tokens per second.
main [b5ffb28]
pr [a527ecc]

| Mac | branch | -ngl 0 | -ngl 1 |
| --- | --- | --- | --- |
| Mini M2 Pro, 12c CPU, 19c GPU | main | 63 tok/s | 63 tok/s |
| Mini M2 Pro, 12c CPU, 19c GPU | pr | 63 tok/s | 191 tok/s |
| Pro (2019) Intel 16c CPU, Vega II GPU | main | 26 tok/s | 25 tok/s* |
| Pro (2019) Intel 16c CPU, Vega II GPU | pr | 26 tok/s | GGML_ASSERT: ggml-metal.m:1088: false |

  • * Intel/Vega II with -ngl 1: Metal text output looks mis-detokenized in main.

Was prompt processing Metal-accelerated at all in main prior to this patch?

@ggerganov
Member

@lshzh-ww

Perplexity numbers are good now. Here are results on M2 Ultra using the test command that you provided:

  • 7B Q4_0 - 672 t/s
  • 70B Q4_0 - 82 t/s

Can do more tests later if needed

@lshzh-ww
Contributor Author

lshzh-ww commented Aug 15, 2023

@ggerganov I tested q4_0, q4_1 and all k-quants, and token generation looks good.

The only minor issue is that @MichaelDays reported broken inference on Intel/AMD GPUs. I don't have an Intel Mac at hand, so I can't test what triggered the assert. Considering that the outputs are already problematic on Intel/AMD in the main branch, I think we can leave this for a follow-up PR and proceed with the merge.

@MichaelDays
No, in the main branch prompt processing runs on the CPU.

Member

@ggerganov ggerganov left a comment


Great job as always!

At some point we have to consider doing a fully quantized multiplication.
Currently, we dequantize the weights and multiply by the F32 activations. Instead, we want to pre-quantize the F32 activations to Q8_0 and then perform the multiplication using integer intrinsics (if available). This should significantly improve the "matrix x matrix" multiplication and might also have some positive effect on "matrix x vector".

The main question that I haven't figured out yet is how to create and store the temporary pre-quantized Q8_0 activations in an elegant way.
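As a rough sketch of the idea (simplified: ggml's real block_q8_0 stores the scale as fp16, the kernels would use SIMD/integer intrinsics instead of scalar loops, and in practice the weights would stay in their 4-bit formats rather than both operands being Q8_0), pre-quantizing a block of 32 activations and doing the dot product in integers could look like this:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int QK8_0 = 32;                 // values per block, as in ggml's Q8_0

struct BlockQ8_0 {
    float  d;                             // per-block scale (fp16 in ggml)
    int8_t qs[QK8_0];                     // quantized values
};

// Quantize 32 F32 activations: d = max|x| / 127, qs[i] = round(x[i] / d).
BlockQ8_0 quantize_block_q8_0(const float * x) {
    float amax = 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        amax = std::max(amax, std::fabs(x[i]));
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;

    BlockQ8_0 b;
    b.d = d;
    for (int i = 0; i < QK8_0; ++i) {
        b.qs[i] = (int8_t) std::lround(x[i] * id);
    }
    return b;
}

// The payoff: the inner product accumulates in int32, which is where integer
// dot-product intrinsics would slot in; floats appear only once per block
// when applying the two scales.
float dot_q8_0(const BlockQ8_0 & a, const BlockQ8_0 & b) {
    int32_t sum = 0;
    for (int i = 0; i < QK8_0; ++i) {
        sum += (int32_t) a.qs[i] * (int32_t) b.qs[i];
    }
    return a.d * b.d * (float) sum;
}
```

This is also why the storage question above matters: these activation blocks are produced per matrix multiplication rather than stored with the model.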

@ggerganov ggerganov mentioned this pull request Aug 16, 2023
@slaren
Member

slaren commented Aug 16, 2023

The main question that I haven't figured out yet is how to create and store the temporary pre-quantized Q8_0 activations in an elegant way.

A possible solution would be to not do it automatically, and instead do it in the graph by manually copying the tensor to a different format. For example:

ggml_mul_mat(ctx, weight, 
        ggml_cpy(ctx, cur, ggml_new_tensor(ctx, GGML_TYPE_Q8_0, cur->n_dims, cur->ne)));

This would somewhat simplify memory management in the backends, but more importantly would give the users more control over the precision of the computations. It would also allow comparison of intermediate results between different backends.

@ggerganov
Member

ggerganov commented Aug 16, 2023

Actually a great point! Should add this to the roadmap:

ggml-org/ggml#455
