@@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
32193219template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
32203220kernel void kernel_get_rows (
32213221 device const void * src0,
3222- device const int * src1,
3222+ device const char * src1,
32233223 device float * dst,
32243224 constant int64_t & ne00,
32253225 constant uint64_t & nb01,
32263226 constant uint64_t & nb02,
32273227 constant int64_t & ne10,
3228+ constant uint64_t & nb10,
3229+ constant uint64_t & nb11,
32283230 constant uint64_t & nb1,
3229- uint tgpig[[threadgroup_position_in_grid]],
3231+ constant uint64_t & nb2,
3232+ uint3 tgpig[[threadgroup_position_in_grid]],
32303233 uint tiitg[[thread_index_in_threadgroup]],
3231- uint tptg [[threads_per_threadgroup]]) {
3232- const int64_t i = tgpig;
3233- const int64_t r = ((device int32_t *) src1)[i];
3234+ uint3 tptg [[threads_per_threadgroup]]) {
3235+ // const int64_t i = tgpig;
3236+ // const int64_t r = ((device int32_t *) src1)[i];
3237+
3238+ const int64_t i10 = tgpig.x ;
3239+ const int64_t i11 = tgpig.y ;
32343240
3235- for (int64_t ind = tiitg; ind < ne00/16 ; ind += tptg) {
3241+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
3242+
3243+ const int64_t i02 = i11;
3244+
3245+ for (int64_t ind = tiitg; ind < ne00/16 ; ind += tptg.x ) {
32363246 float4x4 temp;
32373247 dequantize_func (
3238- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3239- *(((device float4x4 *) ((device char *) dst + i *nb1)) + ind) = temp;
3248+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02 )) + ind/nl, ind%nl, temp);
3249+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10 *nb1)) + ind) = temp;
32403250 }
32413251}
32423252
32433253kernel void kernel_get_rows_f32 (
32443254 device const void * src0,
3245- device const int * src1,
3255+ device const char * src1,
32463256 device float * dst,
32473257 constant int64_t & ne00,
32483258 constant uint64_t & nb01,
32493259 constant uint64_t & nb02,
32503260 constant int64_t & ne10,
3261+ constant uint64_t & nb10,
3262+ constant uint64_t & nb11,
32513263 constant uint64_t & nb1,
3252- uint tgpig[[threadgroup_position_in_grid]],
3264+ constant uint64_t & nb2,
3265+ uint3 tgpig[[threadgroup_position_in_grid]],
32533266 uint tiitg[[thread_index_in_threadgroup]],
3254- uint tptg [[threads_per_threadgroup]]) {
3255- const int64_t i = tgpig;
3256- const int64_t r = ((device int32_t *) src1)[i];
3267+ uint3 tptg [[threads_per_threadgroup]]) {
3268+ const int64_t i10 = tgpig.x ;
3269+ const int64_t i11 = tgpig.y ;
3270+
3271+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
32573272
3258- const int64_t i02 = i/ne10 ;
3273+ const int64_t i02 = i11 ;
32593274
3260- for (int ind = tiitg; ind < ne00; ind += tptg) {
3261- ((device float *) ((device char *) dst + i *nb1))[ind] =
3275+ for (int ind = tiitg; ind < ne00; ind += tptg. x ) {
3276+ ((device float *) ((device char *) dst + i11*nb2 + i10 *nb1))[ind] =
32623277 ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
32633278 }
32643279}
32653280
32663281kernel void kernel_get_rows_f16 (
32673282 device const void * src0,
3268- device const int * src1,
3283+ device const char * src1,
32693284 device float * dst,
32703285 constant int64_t & ne00,
32713286 constant uint64_t & nb01,
32723287 constant uint64_t & nb02,
32733288 constant int64_t & ne10,
3289+ constant uint64_t & nb10,
3290+ constant uint64_t & nb11,
32743291 constant uint64_t & nb1,
3275- uint tgpig[[threadgroup_position_in_grid]],
3292+ constant uint64_t & nb2,
3293+ uint3 tgpig[[threadgroup_position_in_grid]],
32763294 uint tiitg[[thread_index_in_threadgroup]],
3277- uint tptg [[threads_per_threadgroup]]) {
3278- const int64_t i = tgpig;
3279- const int64_t r = ((device int32_t *) src1)[i] ;
3295+ uint3 tptg [[threads_per_threadgroup]]) {
3296+ const int64_t i10 = tgpig. x ;
3297+ const int64_t i11 = tgpig. y ;
32803298
3281- const int64_t i02 = i/ne10 ;
3299+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[ 0 ] ;
32823300
3283- for (int ind = tiitg; ind < ne00; ind += tptg) {
3284- ((device float *) ((device char *) dst + i*nb1))[ind] =
3301+ const int64_t i02 = i11;
3302+
3303+ for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
3304+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
32853305 ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
32863306 }
32873307}
@@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(
35433563
35443564typedef void (get_rows_t )(
35453565 device const void * src0,
3546- device const int * src1,
3566+ device const char * src1,
35473567 device float * dst,
35483568 constant int64_t & ne00,
35493569 constant uint64_t & nb01,
35503570 constant uint64_t & nb02,
35513571 constant int64_t & ne10,
3572+ constant uint64_t & nb10,
3573+ constant uint64_t & nb11,
35523574 constant uint64_t & nb1,
3553- uint, uint, uint);
3575+ constant uint64_t & nb2,
3576+ uint3, uint, uint3);
35543577
35553578// template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
35563579// template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
0 commit comments