Skip to content

Commit 016f9bb

Browse files
committed
metal : fix ggml_get_rows to work with non-cont src1
1 parent 0710b0f commit 016f9bb

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

‎ggml-metal.m‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
15841584
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
15851585
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
15861586
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1587-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
1587+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1588+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1589+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1590+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
15881591

1589-
const int64_t n = ggml_nelements(src1);
1590-
1591-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1592+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
15921593
} break;
15931594
case GGML_OP_RMS_NORM:
15941595
{

‎ggml-metal.metal‎

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
32193219
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
32203220
kernel void kernel_get_rows(
32213221
device const void * src0,
3222-
device const int * src1,
3222+
device const char * src1,
32233223
device float * dst,
32243224
constant int64_t & ne00,
32253225
constant uint64_t & nb01,
32263226
constant uint64_t & nb02,
32273227
constant int64_t & ne10,
3228+
constant uint64_t & nb10,
3229+
constant uint64_t & nb11,
32283230
constant uint64_t & nb1,
3229-
uint tgpig[[threadgroup_position_in_grid]],
3231+
constant uint64_t & nb2,
3232+
uint3 tgpig[[threadgroup_position_in_grid]],
32303233
uint tiitg[[thread_index_in_threadgroup]],
3231-
uint tptg [[threads_per_threadgroup]]) {
3232-
const int64_t i = tgpig;
3233-
const int64_t r = ((device int32_t *) src1)[i];
3234+
uint3 tptg [[threads_per_threadgroup]]) {
3235+
//const int64_t i = tgpig;
3236+
//const int64_t r = ((device int32_t *) src1)[i];
3237+
3238+
const int64_t i10 = tgpig.x;
3239+
const int64_t i11 = tgpig.y;
32343240

3235-
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
3241+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3242+
3243+
const int64_t i02 = i11;
3244+
3245+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
32363246
float4x4 temp;
32373247
dequantize_func(
3238-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3239-
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3248+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3249+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
32403250
}
32413251
}
32423252

32433253
kernel void kernel_get_rows_f32(
32443254
device const void * src0,
3245-
device const int * src1,
3255+
device const char * src1,
32463256
device float * dst,
32473257
constant int64_t & ne00,
32483258
constant uint64_t & nb01,
32493259
constant uint64_t & nb02,
32503260
constant int64_t & ne10,
3261+
constant uint64_t & nb10,
3262+
constant uint64_t & nb11,
32513263
constant uint64_t & nb1,
3252-
uint tgpig[[threadgroup_position_in_grid]],
3264+
constant uint64_t & nb2,
3265+
uint3 tgpig[[threadgroup_position_in_grid]],
32533266
uint tiitg[[thread_index_in_threadgroup]],
3254-
uint tptg [[threads_per_threadgroup]]) {
3255-
const int64_t i = tgpig;
3256-
const int64_t r = ((device int32_t *) src1)[i];
3267+
uint3 tptg [[threads_per_threadgroup]]) {
3268+
const int64_t i10 = tgpig.x;
3269+
const int64_t i11 = tgpig.y;
3270+
3271+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
32573272

3258-
const int64_t i02 = i/ne10;
3273+
const int64_t i02 = i11;
32593274

3260-
for (int ind = tiitg; ind < ne00; ind += tptg) {
3261-
((device float *) ((device char *) dst + i*nb1))[ind] =
3275+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3276+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
32623277
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
32633278
}
32643279
}
32653280

32663281
kernel void kernel_get_rows_f16(
32673282
device const void * src0,
3268-
device const int * src1,
3283+
device const char * src1,
32693284
device float * dst,
32703285
constant int64_t & ne00,
32713286
constant uint64_t & nb01,
32723287
constant uint64_t & nb02,
32733288
constant int64_t & ne10,
3289+
constant uint64_t & nb10,
3290+
constant uint64_t & nb11,
32743291
constant uint64_t & nb1,
3275-
uint tgpig[[threadgroup_position_in_grid]],
3292+
constant uint64_t & nb2,
3293+
uint3 tgpig[[threadgroup_position_in_grid]],
32763294
uint tiitg[[thread_index_in_threadgroup]],
3277-
uint tptg [[threads_per_threadgroup]]) {
3278-
const int64_t i = tgpig;
3279-
const int64_t r = ((device int32_t *) src1)[i];
3295+
uint3 tptg [[threads_per_threadgroup]]) {
3296+
const int64_t i10 = tgpig.x;
3297+
const int64_t i11 = tgpig.y;
32803298

3281-
const int64_t i02 = i/ne10;
3299+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
32823300

3283-
for (int ind = tiitg; ind < ne00; ind += tptg) {
3284-
((device float *) ((device char *) dst + i*nb1))[ind] =
3301+
const int64_t i02 = i11;
3302+
3303+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3304+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
32853305
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
32863306
}
32873307
}
@@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(
35433563

35443564
typedef void (get_rows_t)(
35453565
device const void * src0,
3546-
device const int * src1,
3566+
device const char * src1,
35473567
device float * dst,
35483568
constant int64_t & ne00,
35493569
constant uint64_t & nb01,
35503570
constant uint64_t & nb02,
35513571
constant int64_t & ne10,
3572+
constant uint64_t & nb10,
3573+
constant uint64_t & nb11,
35523574
constant uint64_t & nb1,
3553-
uint, uint, uint);
3575+
constant uint64_t & nb2,
3576+
uint3, uint, uint3);
35543577

35553578
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
35563579
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;

0 commit comments

Comments
 (0)