3 | 3 | from ..vram.layers import enable_vram_management |
4 | 4 | from .file import load_state_dict |
5 | 5 | import torch |
| 6 | +from contextlib import contextmanager |
| 7 | +from transformers.integrations import is_deepspeed_zero3_enabled |
| 8 | +from transformers.utils import ContextManagers |
6 | 9 |
7 | 10 |
8 | 11 | def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): |
9 | 12 | config = {} if config is None else config |
10 | | - # Why do we use `skip_model_initialization`? |
11 | | - # It skips the random initialization of model parameters, |
12 | | - # thereby speeding up model loading and avoiding excessive memory usage. |
13 | | - with skip_model_initialization(): |
| 13 | + with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)): |
14 | 14 | model = model_class(**config) |
15 | 15 | # What is `module_map`? |
16 | 16 | # This is a module mapping table for VRAM management. |
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic |
48 | 48 | state_dict = state_dict_converter(state_dict) |
49 | 49 | else: |
50 | 50 | state_dict = {i: state_dict[i] for i in state_dict} |
51 | | - model.load_state_dict(state_dict, assign=True) |
| 51 | + # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
| 52 | + # Because under ZeRO Stage 3, model parameters are partitioned across multiple GPUs,
| 53 | + # and loading the state dict directly could lead to excessive GPU memory consumption.
| 54 | + if is_deepspeed_zero3_enabled(): |
| 55 | + from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model |
| 56 | + _load_state_dict_into_zero3_model(model, state_dict) |
| 57 | + else: |
| 58 | + model.load_state_dict(state_dict, assign=True) |
52 | 59 | # Why do we call `to()`? |
53 | 60 | # Because some models override the behavior of `to()`, |
54 | 61 | # especially those from libraries like Transformers. |
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor |
79 | 86 | } |
80 | 87 | enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) |
81 | 88 | return model |
| 89 | + |
| 90 | + |
| 91 | +def get_init_context(torch_dtype, device): |
| 92 | + if is_deepspeed_zero3_enabled(): |
| 93 | + from transformers.modeling_utils import set_zero3_state |
| 94 | + import deepspeed |
| 95 | + # Why do we use `deepspeed.zero.Init`?
| 96 | + # It partitions the model weights on the CPU side as the model is constructed,
| 97 | + # so each rank only moves its own shard onto the accelerator instead of the full model.
| 98 | + init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()] |
| 99 | + else: |
| 100 | + # Why do we use `skip_model_initialization`? |
| 101 | + # It skips the random initialization of model parameters, |
| 102 | + # thereby speeding up model loading and avoiding excessive memory usage. |
| 103 | + init_contexts = [skip_model_initialization()] |
| 104 | + |
| 105 | + return init_contexts |
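
A minimal sketch of how the construction-time contexts added here are meant to compose, assuming `transformers` (and, under ZeRO-3, `deepspeed`) are installed. `TinyNet` is a hypothetical stand-in for the real model class, and `torch.device("meta")` stands in for the repo-internal `skip_model_initialization` helper, which is not shown in this diff:

```python
import torch
from transformers.utils import ContextManagers
from transformers.integrations import is_deepspeed_zero3_enabled


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the real model class."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


def get_init_context(torch_dtype, device):
    if is_deepspeed_zero3_enabled():
        import deepspeed
        # Partition parameters as modules are constructed, keeping shards on
        # `remote_device`, so no rank ever materializes the full model.
        return [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype)]
    # Stand-in for `skip_model_initialization`: constructing on the "meta"
    # device skips both memory allocation and random weight initialization.
    return [torch.device("meta")]


with ContextManagers(get_init_context(torch.bfloat16, "cpu")):
    model = TinyNet()  # parameters are ZeRO-partitioned or live on "meta"
```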
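
The weight-loading side follows the same split. Below is a sketch under the same assumptions; `_load_state_dict_into_zero3_model` is a private `transformers` helper (the one this diff calls), so its import path may change between releases:

```python
import torch
from transformers.integrations import is_deepspeed_zero3_enabled


def load_weights(model: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    if is_deepspeed_zero3_enabled():
        from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
        # Copies weights into the ZeRO-3 partitioned parameters piece by piece,
        # so the full checkpoint never has to be resident on a single GPU.
        _load_state_dict_into_zero3_model(model, state_dict)
    else:
        # `assign=True` rebinds the checkpoint tensors to the module instead of
        # copying them into the uninitialized parameters created at init time.
        model.load_state_dict(state_dict, assign=True)
    return model
```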