Commit 1b47e1d

Merge pull request #1272 from modelscope/zero3-fix
Support DeepSpeed ZeRO 3
2 parents abdf66d + b0bf78e

26 files changed: +353 −188 lines


diffsynth/core/loader/model.py

Lines changed: 29 additions & 5 deletions

@@ -3,14 +3,14 @@
 from ..vram.layers import enable_vram_management
 from .file import load_state_dict
 import torch
+from contextlib import contextmanager
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils import ContextManagers
 
 
 def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
     config = {} if config is None else config
-    # Why do we use `skip_model_initialization`?
-    # It skips the random initialization of model parameters,
-    # thereby speeding up model loading and avoiding excessive memory usage.
-    with skip_model_initialization():
+    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
         model = model_class(**config)
     # What is `module_map`?
     # This is a module mapping table for VRAM management.
@@ -48,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
         state_dict = state_dict_converter(state_dict)
     else:
         state_dict = {i: state_dict[i] for i in state_dict}
-    model.load_state_dict(state_dict, assign=True)
+    # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
+    # Because at this stage, model parameters are partitioned across multiple GPUs.
+    # Loading them directly could lead to excessive GPU memory consumption.
+    if is_deepspeed_zero3_enabled():
+        from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
+        _load_state_dict_into_zero3_model(model, state_dict)
+    else:
+        model.load_state_dict(state_dict, assign=True)
     # Why do we call `to()`?
     # Because some models override the behavior of `to()`,
     # especially those from libraries like Transformers.
@@ -79,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
     }
     enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
     return model
+
+
+def get_init_context(torch_dtype, device):
+    if is_deepspeed_zero3_enabled():
+        from transformers.modeling_utils import set_zero3_state
+        import deepspeed
+        # Why do we use "deepspeed.zero.Init"?
+        # Weight segmentation of the model can be performed on the CPU side
+        # and loading the segmented weights onto the computing card
+        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
+    else:
+        # Why do we use `skip_model_initialization`?
+        # It skips the random initialization of model parameters,
+        # thereby speeding up model loading and avoiding excessive memory usage.
+        init_contexts = [skip_model_initialization()]
+
+    return init_contexts
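For readers unfamiliar with ZeRO-3 initialization, here is a minimal sketch of what `deepspeed.zero.Init` does at construction time, which is what `get_init_context` enables. It is illustrative only: it assumes the script is launched through a distributed launcher (deepspeed or accelerate) with a ZeRO-3 config, and the `Linear` layer and printed shapes are not part of this PR.

import torch
import deepspeed

# Parameters created inside this context are partitioned across ranks immediately,
# so no rank ever materializes the full, unsharded model.
with deepspeed.zero.Init(remote_device="cpu", dtype=torch.bfloat16):
    layer = torch.nn.Linear(1024, 1024)

print(layer.weight.shape)     # torch.Size([0]) locally: this rank holds only a shard
print(layer.weight.ds_shape)  # torch.Size([1024, 1024]): the full logical shape

# Temporarily gather the full parameter when it is genuinely needed.
with deepspeed.zero.GatheredParameters(layer.weight):
    print(layer.weight.shape)  # torch.Size([1024, 1024])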

diffsynth/diffusion/logger.py

Lines changed: 2 additions & 2 deletions

@@ -18,8 +18,8 @@ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_ste
 
     def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
         accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
         if accelerator.is_main_process:
-            state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save
 
     def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
         accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
         if accelerator.is_main_process:
-            state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
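Why `accelerator.get_state_dict(model)` moved above the `is_main_process` check: under ZeRO-3 the full state dict has to be assembled from parameter shards held by every rank, so the call is effectively a collective operation and must run on all processes, while writing to disk remains a main-process-only step. A minimal sketch of the resulting pattern (the `torch.save` call and file name are illustrative, not taken from this PR):

accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)  # collective under ZeRO-3: every rank participates
if accelerator.is_main_process:
    torch.save(state_dict, "checkpoint.pt")     # only the main process writes the gathered weights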

diffsynth/diffusion/runner.py

Lines changed: 2 additions & 1 deletion

@@ -27,7 +27,7 @@ def launch_training_task(
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
-
+    model.to(device=accelerator.device)
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
 
     for epoch_id in range(num_epochs):
@@ -59,6 +59,7 @@ def launch_data_process_task(
     num_workers = args.dataset_num_workers
 
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
+    model.to(device=accelerator.device)
     model, dataloader = accelerator.prepare(model, dataloader)
 
     for data_id, data in enumerate(tqdm(dataloader)):
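For reference, a minimal sketch of how this training path could be driven with ZeRO-3 via Accelerate. The plugin values are illustrative assumptions (in practice the ZeRO stage usually comes from an accelerate or DeepSpeed config file), and `model`, `optimizer`, `dataloader`, `scheduler` are assumed to exist as in `launch_training_task`:

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Illustrative: request ZeRO stage 3 through the DeepSpeed plugin.
deepspeed_plugin = DeepSpeedPlugin(zero_stage=3, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

# As in the diff above: move the model to the accelerator device before prepare().
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)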

diffsynth/models/wan_video_dit.py

Lines changed: 8 additions & 19 deletions

@@ -5,6 +5,7 @@
 from typing import Tuple, Optional
 from einops import rearrange
 from .wan_video_camera_controller import SimpleAdapter
+from ..core.gradient import gradient_checkpoint_forward
 
 try:
     import flash_attn_interface
@@ -379,27 +380,15 @@ def forward(self,
             self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
             self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
-
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
 
         for block in self.blocks:
-            if self.training and use_gradient_checkpointing:
-                if use_gradient_checkpointing_offload:
-                    with torch.autograd.graph.save_on_cpu():
-                        x = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block),
-                            x, context, t_mod, freqs,
-                            use_reentrant=False,
-                        )
-                else:
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
+            if self.training:
+                x = gradient_checkpoint_forward(
+                    block,
+                    use_gradient_checkpointing,
+                    use_gradient_checkpointing_offload,
+                    x, context, t_mod, freqs
+                )
             else:
                 x = block(x, context, t_mod, freqs)

diffsynth/models/wan_video_dit_s2v.py

Lines changed: 13 additions & 39 deletions

@@ -4,6 +4,7 @@
 import torch.nn.functional as F
 from typing import Tuple
 from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
+from ..core.gradient import gradient_checkpoint_forward
 
 
 def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ def forward(
         t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
         t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
 
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
-
         for block_id, block in enumerate(self.blocks):
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x,
-                        context,
-                        t_mod,
-                        seq_len_x,
-                        pre_compute_freqs[0],
-                        use_reentrant=False,
-                    )
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                        x,
-                        use_reentrant=False,
-                    )
-            elif use_gradient_checkpointing:
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    x,
-                    context,
-                    t_mod,
-                    seq_len_x,
-                    pre_compute_freqs[0],
-                    use_reentrant=False,
-                )
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                    x,
-                    use_reentrant=False,
-                )
-            else:
-                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
-                x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
+            x = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x, context, t_mod, seq_len_x, pre_compute_freqs[0]
+            )
+            x = gradient_checkpoint_forward(
+                lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x
+            )
 
         x = x[:, :seq_len_x]
         x = self.head(x, t[:-1])

diffsynth/models/wan_video_vace.py

Lines changed: 8 additions & 21 deletions

@@ -1,6 +1,6 @@
 import torch
 from .wan_video_dit import DiTBlock
-
+from ..core.gradient import gradient_checkpoint_forward
 
 class VaceWanAttentionBlock(DiTBlock):
     def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ def forward(
             dim=1) for u in c
         ])
 
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
-
         for block in self.vace_blocks:
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    c = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        c, x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
-            elif use_gradient_checkpointing:
-                c = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    c, x, context, t_mod, freqs,
-                    use_reentrant=False,
-                )
-            else:
-                c = block(c, x, context, t_mod, freqs)
+            c = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                c, x, context, t_mod, freqs
+            )
+
         hints = torch.unbind(c)[:-1]
         return hints
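The three model files above now route their blocks through the new `gradient_checkpoint_forward` helper imported from `diffsynth/core/gradient.py`, which is not included in this excerpt. Based on the code it replaces, a plausible sketch of that helper (an assumption about its implementation, not the file's actual contents) is:

import torch

def gradient_checkpoint_forward(module, use_gradient_checkpointing, use_gradient_checkpointing_offload, *args):
    # Offload variant: keep checkpointed activations on the CPU between forward and backward.
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False)
    # Plain activation checkpointing: recompute activations during the backward pass.
    elif use_gradient_checkpointing:
        return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False)
    # No checkpointing: fall through to a regular forward call.
    return module(*args)

Because the first argument is just a callable, callers can pass either a transformer block or a lambda such as the `after_transformer_block` wrapper in wan_video_dit_s2v.py.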
