@@ -223,6 +223,27 @@ def get_int_from_env(env_keys, default):
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
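+# Granularity for sharding weight-only quantized (WOQ) weights across tensor-parallel
+# ranks; defaults to 64 and is overridden by the checkpoint's group_size or --group-size.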
+tp_grain_size = 64
+if args.ipex_weight_only_quantization and args.low_precision_checkpoint != "":
+    pathname = args.low_precision_checkpoint
+    assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
+    if os.path.isdir(pathname):
+        try:
+            with open(pathname + "/config.json") as f:
+                quant_model_config = json.load(f)
+            tp_grain_size = int(
+                quant_model_config["quantization_config"]["group_size"]
+            )
+        except Exception as e:
+            print(f"Failed to get group_size from config.json: {e}")
+    elif args.group_size > 0:
+        tp_grain_size = args.group_size
+    else:
+        print(
+            "Warning: cannot get group_size from config.json or --group-size, "
+            "using default value 64 for tp_grain_size"
+        )
+
 
 class HuggingFaceModel(BaseLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -399,6 +420,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
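+            # Align DeepSpeed's tensor-parallel sharding granularity with the
+            # weight-only quantization group size read above.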
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module
@@ -537,10 +561,13 @@ def write_checkpoints_json():
                     num_heads = model.config.num_attention_heads
                     rank = local_rank
 
-                    layers_split_by_N = [
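+                    # Attention projection weights are split across ranks per attention head.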
+                    mha_layers_split_by_N = [
                         "q_proj",
                         "k_proj",
                         "v_proj",
+                    ]
+                    # mlp is split with grain size = tp_grain_size
+                    mlp_layers_split_by_N = [
                         "gate_proj",
                         "up_proj",
                         "fc_in",
@@ -549,23 +576,26 @@ def write_checkpoints_json():
                         "w1",
                         "w3",
                     ]
-                    layers_split_by_K = [
+                    mha_layers_split_by_K = [
                         "o_proj",
+                        "out_proj",
+                    ]
+                    # mlp is split with grain size = tp_grain_size
+                    mlp_layers_split_by_K = [
                         "down_proj",
                         "fc_out",
                         "fc2",
-                        "out_proj",
                         "dense",
                         "dense_4h_to_h",
                         "w2",
                     ]
+                    # lm_head is split with grain size = tp_grain_size
                     lm_head_layers = ["lm_head"]  # split by K but not quantized
                     quantization_method = quant_model_config["quantization_config"][
                         "quant_method"
                     ]
                     head_range = [0]
                     head_per_rank = num_heads // world_size
-
                     for i in range(0, world_size):
                         head_this_rank = head_per_rank
                         if i < num_heads % world_size:
@@ -578,7 +608,7 @@ def write_checkpoints_json():
                             )
                             if "bias" in key:
                                 continue
-                            if any(substring in key for substring in layers_split_by_N):
+                            if any(substring in key for substring in mha_layers_split_by_N):
                                 data = low_precision_checkpoint_dict[key]
                                 if quantization_method == "awq":
                                     # awq qweight: [K, N // 8]
@@ -592,7 +622,48 @@ def write_checkpoints_json():
                                     raise AssertionError(
                                         f"{quantization_method} is not supported yet."
                                     )
-                            if any(substring in key for substring in layers_split_by_K):
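+                            # Split the N dimension into tp_grain_size-wide grains and
+                            # distribute them evenly; remainder grains go to the lowest ranks.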
+                            elif any(
+                                substring in key for substring in mlp_layers_split_by_N
+                            ):
+                                data = low_precision_checkpoint_dict[key]
+                                if quantization_method == "awq":
+                                    # awq qweight: [K, N // 8]
+                                    # awq scales: [K // G, N]
+                                    # awq qzeros: [K // G, N // 8]
+                                    if "scales" in key:
+                                        assert (
+                                            data.shape[1] % tp_grain_size == 0
+                                        ), "N must be divisible by tp_grain_size"
+                                        grains = data.shape[1] // tp_grain_size
+                                        dim = tp_grain_size
+                                    else:
+                                        assert (
+                                            data.shape[1] * 8
+                                        ) % tp_grain_size == 0, (
+                                            "N must be divisible by tp_grain_size"
+                                        )
+                                        grains = data.shape[1] // (tp_grain_size // 8)
+                                        dim = tp_grain_size // 8
+                                    grains_per_rank = grains // world_size
+                                    grains_rem = grains % world_size
+                                    grains_start = grains_per_rank * local_rank + min(
+                                        local_rank, grains_rem
+                                    )
+                                    grains_end = (
+                                        grains_start
+                                        + grains_per_rank
+                                        + (1 if local_rank < grains_rem else 0)
+                                    )
+                                    low_precision_checkpoint_dict[key] = data[
+                                        :, grains_start * dim : grains_end * dim
+                                    ]
+                                else:
+                                    raise AssertionError(
+                                        f"{quantization_method} is not supported yet."
+                                    )
+                            elif any(
+                                substring in key for substring in mha_layers_split_by_K
+                            ):
                                 data = low_precision_checkpoint_dict[key]
                                 if quantization_method == "awq":
                                     # awq qweight: [K, N // 8]
@@ -612,18 +683,61 @@ def write_checkpoints_json():
                                     raise AssertionError(
                                         f"{quantization_method} is not supported yet."
                                     )
-                            if any(substring in key for substring in lm_head_layers):
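+                            # Same grain-based split as above, but along the K (row) dimension.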
+                            elif any(
+                                substring in key for substring in mlp_layers_split_by_K
+                            ):
+                                data = low_precision_checkpoint_dict[key]
+                                if quantization_method == "awq":
+                                    # awq qweight: [K, N // 8]
+                                    # awq scales: [K // G, N]
+                                    # awq qzeros: [K // G, N // 8]
+                                    if "qweight" in key:
+                                        assert (
+                                            data.shape[0] % tp_grain_size == 0
+                                        ), "K must be divisible by tp_grain_size"
+                                        grains = data.shape[0] // tp_grain_size
+                                        dim = tp_grain_size
+                                    else:
+                                        grains = data.shape[0]
+                                        dim = 1
+                                    grains_per_rank = grains // world_size
+                                    grains_rem = grains % world_size
+                                    grains_start = grains_per_rank * local_rank + min(
+                                        local_rank, grains_rem
+                                    )
+                                    grains_end = (
+                                        grains_start
+                                        + grains_per_rank
+                                        + (1 if local_rank < grains_rem else 0)
+                                    )
+                                    low_precision_checkpoint_dict[key] = data[
+                                        grains_start * dim : grains_end * dim
+                                    ]
+                                else:
+                                    raise AssertionError(
+                                        f"{quantization_method} is not supported yet."
+                                    )
+                            elif any(substring in key for substring in lm_head_layers):
                                 # lm_head: [N, K] (not quantized)
                                 # Same for both AWQ and GPTQ
                                 data = low_precision_checkpoint_dict[key]
-                                if data.shape[-1] % head_range[-1] == 0:
-                                    dim = data.shape[-1] // head_range[-1]
-                                else:
-                                    dim = data.shape[-1] // world_size
-                                q_head_start = local_rank
-                                q_head_end = local_rank + 1
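+                                # Split along K in tp_grain_size units instead of per attention head.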
+                                assert (
+                                    data.shape[1] % tp_grain_size == 0
+                                ), "K must be divisible by tp_grain_size"
+                                grains = data.shape[1] // tp_grain_size
+                                dim = tp_grain_size
+                                grains_per_rank = grains // world_size
+                                grains_rem = grains % world_size
+                                grains_start = grains_per_rank * local_rank + min(
+                                    local_rank, grains_rem
+                                )
+                                grains_end = (
+                                    grains_start
+                                    + grains_per_rank
+                                    + (1 if local_rank < grains_rem else 0)
+                                )
                                 low_precision_checkpoint_dict[key] = data[
-                                    :, q_head_start * dim : q_head_end * dim
+                                    :, grains_start * dim : grains_end * dim
                                 ]
                     low_precision_checkpoint = (
                         low_precision_checkpoint_dict,
@@ -1381,6 +1495,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self._model = self._model.module
@@ -2146,6 +2263,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module