@@ -183,6 +183,14 @@ def decorator(func):
183183 help = "Quantize weight symmetrically for weight only quantization. It usually brings better latency at"
184184 " the cost of accuracy. It has not effect if you are loading low-precision checkpoints." ,
185185)
+parser.add_argument(
+    "--low-precision-checkpoint",
+    default="",
+    type=str,
+    help="Low precision checkpoint file generated by algorithms such as GPTQ. It contains"
+    " INT4 weights, scales, zero points, etc., and is used for better accuracy of"
+    " weight-only quantization with INT4 weight.",
+)
 
 args = parser.parse_args()
 
@@ -394,7 +402,84 @@ def write_checkpoints_json():
         )
 
         self.model = self.model.module
+        import pathlib
+
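+        # Load the user-provided low-precision checkpoint (a single .pt/.pth/.safetensors
+        # file or a directory of such files) and record its quantization method so it
+        # can be handed to ipex.llm.optimize later.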
+        if args.low_precision_checkpoint != "":
+            pathname = args.low_precision_checkpoint
+            assert os.path.exists(
+                pathname
+            ), f"Checkpoint file does not exist: {pathname}"
+            if os.path.isfile(pathname):
+                low_precision_checkpoint = None
+                if pathname.endswith(".pt") or pathname.endswith(".pth"):
+                    low_precision_checkpoint = torch.load(pathname, weights_only=True)
+                elif pathname.endswith(".safetensors"):
+                    try:
+                        import safetensors.torch
+                    except ImportError:
+                        print(
+                            "Please install the safetensors package to load safetensors checkpoints."
+                        )
+                        exit(1)
+                    low_precision_checkpoint = safetensors.torch.load_file(pathname)
+                assert (
+                    low_precision_checkpoint is not None
+                ), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth or .safetensors file."
+
+                # A single checkpoint file carries no HF config, so assume GPTQ by default.
+                quant_method = {"quant_method": "gptq"}
+
+            elif os.path.isdir(pathname):
+                low_precision_checkpoint = {}
+                for pattern in ["*.pt", "*.pth"]:
+                    files = list(pathlib.Path(pathname).glob(pattern))
+                    if files:
+                        for f in files:
+                            data_f = torch.load(f, weights_only=True)
+                            low_precision_checkpoint.update(data_f)
+                        break
+                if not low_precision_checkpoint:
+                    files = list(pathlib.Path(pathname).glob("*.safetensors"))
+                    if files:
+                        try:
+                            import safetensors.torch
+                        except ImportError:
+                            print(
+                                "Please install the safetensors package to load safetensors checkpoints."
+                            )
+                            exit(1)
+                        for f in files:
+                            data_f = safetensors.torch.load_file(f)
+                            low_precision_checkpoint.update(data_f)
+                assert (
+                    len(low_precision_checkpoint) > 0
+                ), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path {pathname}."
 
+                try:
+                    with open(pathname + "/config.json") as f:
+                        quant_model_config = json.load(f)
+                    quant_method = {
+                        "quant_method": quant_model_config["quantization_config"][
+                            "quant_method"
+                        ]
+                    }
+                except Exception as e:
+                    print(
+                        "warning: loading HF config.json to get `quant_method` failed, due to ",
+                        e,
+                    )
+                    print("warning: specifying `quant_method` = `gptq` by default.")
+                    quant_method = {"quant_method": "gptq"}
+
+            else:
+                raise AssertionError(
+                    f"Invalid low-precision-checkpoint: {pathname}."
+                    " Should be a .pt/.pth/.safetensors file or a directory containing them."
+                )
+
+            low_precision_checkpoint = (low_precision_checkpoint, quant_method)
+            low_precision_checkpoint_dict = low_precision_checkpoint[0]
+        else:
+            low_precision_checkpoint = None
         if self._with_ipex:
             ipex_woq_enabled = args.ipex_weight_only_quantization
             if ipex_woq_enabled:
@@ -447,13 +532,111 @@ def write_checkpoints_json():
                     group_size=args.group_size,
                     weight_qscheme=weight_qscheme,
                 )
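+            # Shard the low-precision checkpoint for tensor parallelism: layers whose
+            # output dimension (N) is split across ranks are sliced along the last dim,
+            # layers whose input dimension (K) is split are sliced along the first dim,
+            # and the unquantized lm_head is sliced along K. Slices follow the per-rank
+            # head ranges computed below.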
+            model = self.model
+            if low_precision_checkpoint is not None:
+                num_heads = model.config.num_attention_heads
+                rank = local_rank
+
+                layers_split_by_N = [
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "gate_proj",
+                    "up_proj",
+                    "fc_in",
+                    "fc1",
+                    "query_key_value",
+                    "w1",
+                    "w3",
+                ]
+                layers_split_by_K = [
+                    "o_proj",
+                    "down_proj",
+                    "fc_out",
+                    "fc2",
+                    "out_proj",
+                    "dense",
+                    "dense_4h_to_h",
+                    "w2",
+                ]
+                lm_head_layers = ["lm_head"]  # split by K but not quantized
+                quantization_method = quant_method["quant_method"]
+                head_range = [0]
+                head_per_rank = num_heads // world_size
+
+                for i in range(0, world_size):
+                    head_this_rank = head_per_rank
+                    if i < num_heads % world_size:
+                        head_this_rank += 1
+                    head_range.append(head_range[-1] + head_this_rank)
+                for key in low_precision_checkpoint_dict.keys():
+                    q_head_start = head_range[rank]
+                    q_head_end = q_head_start + (
+                        head_range[rank + 1] - head_range[rank]
+                    )
+                    if "bias" in key:
+                        continue
+                    if any(substring in key for substring in layers_split_by_N):
+                        data = low_precision_checkpoint_dict[key]
+                        if quantization_method == "awq":
+                            # awq qweight: [K, N // 8]
+                            # awq scales: [K // G, N]
+                            # awq qzeros: [K // G, N // 8]
+                            dim = data.shape[-1] // head_range[-1]
+                            low_precision_checkpoint_dict[key] = data[
+                                :, q_head_start * dim : q_head_end * dim
+                            ]
+                        else:
+                            raise AssertionError(
+                                f"{quantization_method} is not supported yet."
+                            )
+                    if any(substring in key for substring in layers_split_by_K):
+                        data = low_precision_checkpoint_dict[key]
+                        if quantization_method == "awq":
+                            # awq qweight: [K, N // 8]
+                            # awq scales: [K // G, N]
+                            # awq qzeros: [K // G, N // 8]
+                            if data.shape[0] % head_range[-1] == 0:
+                                dim = data.shape[0] // head_range[-1]
+                            else:
+                                assert data.shape[0] % world_size == 0
+                                dim = data.shape[0] // world_size
+                                q_head_start = local_rank
+                                q_head_end = local_rank + 1
+                            low_precision_checkpoint_dict[key] = data[
+                                q_head_start * dim : q_head_end * dim
+                            ]
+                        else:
+                            raise AssertionError(
+                                f"{quantization_method} is not supported yet."
+                            )
+                    if any(substring in key for substring in lm_head_layers):
+                        # lm_head: [N, K] (not quantized)
+                        # Same for both AWQ and GPTQ
+                        data = low_precision_checkpoint_dict[key]
+                        if data.shape[-1] % head_range[-1] == 0:
+                            dim = data.shape[-1] // head_range[-1]
+                        else:
+                            dim = data.shape[-1] // world_size
+                            q_head_start = local_rank
+                            q_head_end = local_rank + 1
+                        low_precision_checkpoint_dict[key] = data[
+                            :, q_head_start * dim : q_head_end * dim
+                        ]
+                low_precision_dict = (low_precision_checkpoint_dict, quant_method)
+            else:
+                low_precision_dict = None
+
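+            # Forward the sharded (state_dict, quant_method) pair to ipex.llm.optimize
+            # through its low_precision_checkpoint argument (None when no checkpoint
+            # was provided).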
             self.model = ipex.llm.optimize(
                 self.model.eval(),
                 dtype=infer_dtype,
                 quantization_config=qconfig if ipex_woq_enabled else None,
                 inplace=True,
                 deployment_mode=False,
                 cache_weight_for_large_batch=args.cache_weight_for_large_batch,
+                low_precision_checkpoint=low_precision_dict,
             )
 
         self.base_model = self.model