
Commit ae09c58

support deepspeed tp load int4 checkpoint (#3328)
* tp int4 checkpoint
* small changes.
* update
* update lm_head
* unnecessary change
* some change on tp
* tp update
* simplify the code.
* modify run_accuracy_with_deepspeed.
* modify according to comment.
* Support low precision checkpoint with TP in llm.optimize
* Revert some changes in llm.optimize
* fix bug for gpt-j.
* remove unnecessary change.
* remove unnecessary change.
* fix bug.
* support mixtral.
* support mixtral.
* flake8 format.
* fix bug.

---------

Co-authored-by: Tao, Ran <[email protected]>
Co-authored-by: Xia, Weiwen <[email protected]>
1 parent 4679764 commit ae09c58

File tree: 7 files changed (+494, -64 lines)
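The substance of the change is in run_accuracy_with_deepspeed.py: when --low-precision-checkpoint points at an INT4 checkpoint (the sharding path currently handles the AWQ layout), every tensor-parallel rank loads the full state dict and keeps only its own slice, splitting column-parallel layers (q_proj, k_proj, v_proj, up_proj, ...) along N and row-parallel layers (o_proj, down_proj, ...) along K, with attention heads spread as evenly as possible across ranks. Below is a minimal standalone sketch of that sharding scheme, assuming the AWQ packing noted in the diff's comments; the helper names are illustrative only and not part of the commit.

# Illustrative sketch (not part of the commit): how the diff below shards an
# AWQ-packed tensor for tensor parallelism. Column-parallel layers are split
# along N (last dim), row-parallel layers along K (first dim).
import torch


def build_head_range(num_heads, world_size):
    # Cumulative head counts per rank; leftover heads go to the lowest ranks.
    head_range = [0]
    head_per_rank = num_heads // world_size
    for i in range(world_size):
        heads = head_per_rank + (1 if i < num_heads % world_size else 0)
        head_range.append(head_range[-1] + heads)
    return head_range


def shard_awq_tensor(data, split_by, rank, head_range):
    # AWQ layout: qweight [K, N // 8], scales [K // G, N], qzeros [K // G, N // 8].
    # Simplified: assumes the sharded dim divides evenly by the total head count.
    start, end = head_range[rank], head_range[rank + 1]
    if split_by == "N":  # column-parallel (q_proj, k_proj, v_proj, up_proj, ...)
        dim = data.shape[-1] // head_range[-1]
        return data[:, start * dim : end * dim]
    else:  # split_by == "K", row-parallel (o_proj, down_proj, ...)
        dim = data.shape[0] // head_range[-1]
        return data[start * dim : end * dim]


# Worked example: 32 attention heads on 3 ranks -> head_range == [0, 11, 22, 32]
print(build_head_range(32, 3))
qweight = torch.zeros(4096, 4096 // 8, dtype=torch.int32)  # dummy AWQ qweight
print(shard_awq_tensor(qweight, "N", rank=0, head_range=build_head_range(32, 3)).shape)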

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 183 additions & 0 deletions
@@ -183,6 +183,14 @@ def decorator(func):
     help="Quantize weight symmetrically for weight only quantization. It usually brings better latency at"
     " the cost of accuracy. It has not effect if you are loading low-precision checkpoints.",
 )
+parser.add_argument(
+    "--low-precision-checkpoint",
+    default="",
+    type=str,
+    help="Low precision checkpoint file generated by algorithms, such as GPTQ. It contains"
+    " INT4 weights, scales, zero points, etc. For better accuracy of weight only"
+    " quantization with INT4 weight.",
+)

 args = parser.parse_args()

@@ -394,7 +402,84 @@ def write_checkpoints_json():
         )

         self.model = self.model.module
+        import pathlib
+
+        if args.low_precision_checkpoint != "":
+            pathname = args.low_precision_checkpoint
+            assert os.path.exists(
+                pathname
+            ), f"Checkpoint file does not exist: {pathname}"
+            if os.path.isfile(pathname):
+                low_precision_checkpoint = None
+                if pathname.endswith(".pt") or pathname.endswith(".pth"):
+                    low_precision_checkpoint = torch.load(pathname, weights_only=True)
+                elif pathname.endswith(".safetensors"):
+                    try:
+                        import safetensors
+                    except ImportError:
+                        print(
+                            "Please install safetensors package to load safetensors checkpoint."
+                        )
+                        exit(1)
+                    low_precision_checkpoint = safetensors.torch.load_file(pathname)
+                assert (
+                    low_precision_checkpoint is not None
+                ), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth or .safetensors file."
+
+                quant_method = {"quant_method": "gptq"}
+
+            elif os.path.isdir(pathname):
+                low_precision_checkpoint = {}
+                for pattern in ["*.pt", "*.pth"]:
+                    files = list(pathlib.Path(pathname).glob(pattern))
+                    if files:
+                        for f in files:
+                            data_f = torch.load(f, weights_only=True)
+                            low_precision_checkpoint.update(data_f)
+                        break
+                if not low_precision_checkpoint:
+                    files = list(pathlib.Path(pathname).glob("*.safetensors"))
+                    if files:
+                        try:
+                            import safetensors
+                        except ImportError:
+                            print(
+                                "Please install safetensors package to load safetensors checkpoint."
+                            )
+                            exit(1)
+                        for f in files:
+                            data_f = safetensors.torch.load_file(f)
+                            low_precision_checkpoint.update(data_f)
+                assert (
+                    len(low_precision_checkpoint) > 0
+                ), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path {pathname}."

+                try:
+                    with open(pathname + "/config.json") as f:
+                        quant_model_config = json.load(f)
+                    quant_method = {
+                        "quant_method": quant_model_config["quantization_config"][
+                            "quant_method"
+                        ]
+                    }
+                except Exception as e:
+                    print(
+                        "warning: loading HF config.json to get `quant_method` failed, due to ",
+                        e,
+                    )
+                    print("warning: specifying `quant_method` = `gptq` by default.")
+                    quant_method = {"quant_method": "gptq"}
+
+            else:
+                raise AssertionError(
+                    f"Invalid low-precision-checkpoint: {pathname}."
+                    " Should be a .pt/.pth/.safetensors file or a directory containing them."
+                )
+
+            low_precision_checkpoint = (low_precision_checkpoint, quant_method)
+            low_precision_checkpoint_dict = low_precision_checkpoint[0]
+        else:
+            low_precision_checkpoint = None
         if self._with_ipex:
             ipex_woq_enabled = args.ipex_weight_only_quantization
             if ipex_woq_enabled:
@@ -447,13 +532,111 @@ def write_checkpoints_json():
                     group_size=args.group_size,
                     weight_qscheme=weight_qscheme,
                 )
+                model = self.model
+                if low_precision_checkpoint is not None:
+                    num_heads = model.config.num_attention_heads
+                    rank = local_rank
+
+                    layers_split_by_N = [
+                        "q_proj",
+                        "k_proj",
+                        "v_proj",
+                        "gate_proj",
+                        "up_proj",
+                        "fc_in",
+                        "fc1",
+                        "query_key_value",
+                        "w1",
+                        "w3",
+                    ]
+                    layers_split_by_K = [
+                        "o_proj",
+                        "down_proj",
+                        "fc_out",
+                        "fc2",
+                        "out_proj",
+                        "dense",
+                        "dense_4h_to_h",
+                        "w2",
+                    ]
+                    lm_head_layers = ["lm_head"]  # split by K but not quantized
+                    quantization_method = quant_model_config["quantization_config"][
+                        "quant_method"
+                    ]
+                    head_range = [0]
+                    head_per_rank = num_heads // world_size
+
+                    for i in range(0, world_size):
+                        head_this_rank = head_per_rank
+                        if i < num_heads % world_size:
+                            head_this_rank += 1
+                        head_range.append(head_range[-1] + head_this_rank)
+                    for key in low_precision_checkpoint[0].keys():
+                        q_head_start = head_range[rank]
+                        q_head_end = q_head_start + (
+                            head_range[rank + 1] - head_range[rank]
+                        )
+                        if "bias" in key:
+                            continue
+                        if any(substring in key for substring in layers_split_by_N):
+                            data = low_precision_checkpoint_dict[key]
+                            if quantization_method == "awq":
+                                # awq qweight: [K, N // 8]
+                                # awq scales: [K // G, N]
+                                # awq qzeros: [K // G, N // 8]
+                                dim = data.shape[-1] // head_range[-1]
+                                low_precision_checkpoint_dict[key] = data[
+                                    :, q_head_start * dim : q_head_end * dim
+                                ]
+                            else:
+                                raise AssertionError(
+                                    f"{quantization_method} is not supported yet."
+                                )
+                        if any(substring in key for substring in layers_split_by_K):
+                            data = low_precision_checkpoint_dict[key]
+                            if quantization_method == "awq":
+                                # awq qweight: [K, N // 8]
+                                # awq scales: [K // G, N]
+                                # awq qzeros: [K // G, N // 8]
+                                if data.shape[0] % head_range[-1] == 0:
+                                    dim = data.shape[0] // head_range[-1]
+                                else:
+                                    assert data.shape[0] % world_size == 0
+                                    dim = data.shape[0] // world_size
+                                    q_head_start = local_rank
+                                    q_head_end = local_rank + 1
+                                low_precision_checkpoint_dict[key] = data[
+                                    q_head_start * dim : q_head_end * dim
+                                ]
+                            else:
+                                raise AssertionError(
+                                    f"{quantization_method} is not supported yet."
+                                )
+                        if any(substring in key for substring in lm_head_layers):
+                            # lm_head: [N, K] (not quantized)
+                            # Same for both AWQ and GPTQ
+                            data = low_precision_checkpoint_dict[key]
+                            if data.shape[-1] % head_range[-1] == 0:
+                                dim = data.shape[-1] // head_range[-1]
+                            else:
+                                dim = data.shape[-1] // world_size
+                                q_head_start = local_rank
+                                q_head_end = local_rank + 1
+                            low_precision_checkpoint_dict[key] = data[
+                                :, q_head_start * dim : q_head_end * dim
+                            ]
+                    low_precision_dict = (low_precision_checkpoint_dict, quant_method)
+                else:
+                    low_precision_dict = None
+
             self.model = ipex.llm.optimize(
                 self.model.eval(),
                 dtype=infer_dtype,
                 quantization_config=qconfig if ipex_woq_enabled else None,
                 inplace=True,
                 deployment_mode=False,
                 cache_weight_for_large_batch=args.cache_weight_for_large_batch,
+                low_precision_checkpoint=low_precision_dict,
             )

         self.base_model = self.model
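For reference, after the per-rank slicing the checkpoint is handed to ipex.llm.optimize as a (state_dict, {"quant_method": ...}) tuple, as the last hunk shows. A condensed, hypothetical usage sketch follows; the model id, shard path, and the qconfig helper call are assumptions for illustration, not taken from this commit.

# Hypothetical usage sketch: load one rank's INT4 shard and pass it to
# ipex.llm.optimize via the low_precision_checkpoint argument, as the diff does.
import intel_extension_for_pytorch as ipex
import safetensors.torch
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
state_dict = safetensors.torch.load_file("rank0_int4_shard.safetensors")  # hypothetical per-rank slice
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping()  # assumed WOQ qconfig helper

model = ipex.llm.optimize(
    model.eval(),
    dtype=torch.bfloat16,
    quantization_config=qconfig,
    inplace=True,
    deployment_mode=False,
    low_precision_checkpoint=(state_dict, {"quant_method": "awq"}),
)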
