Skip to content

Support DeepSpeed ZeRO 3#1272

Merged
Artiprocher merged 7 commits intomainfrom
zero3-fix
Feb 6, 2026
Merged

Support DeepSpeed ZeRO 3#1272
Artiprocher merged 7 commits intomainfrom
zero3-fix

Conversation

@Artiprocher
Copy link
Collaborator

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Artiprocher, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the framework's capability to train large models by integrating DeepSpeed ZeRO Stage 3. The changes streamline model initialization and state loading for distributed environments, refactor gradient checkpointing for cleaner code, and adjust VAE feature handling. These updates collectively aim to improve memory efficiency and scalability for training, particularly on NPU setups, and are supported by updated documentation and configuration examples.

Highlights

  • DeepSpeed ZeRO 3 Integration: Implemented comprehensive support for DeepSpeed ZeRO Stage 3, a memory optimization technique for large models. This includes specialized model initialization and state dictionary loading mechanisms to handle parameter partitioning across multiple GPUs efficiently.
  • Gradient Checkpointing Refactoring: Centralized gradient checkpointing logic into a new utility function, gradient_checkpoint_forward, reducing code duplication and improving maintainability across various model architectures (DiT, DiT-S2V, VACE, and xFuser).
  • VAE Feature Cache Management: Modified several VAE-related modules to explicitly return and manage feature caches and indices, which is crucial for consistent behavior in distributed training environments and potentially for memory optimization.
  • Training Script and Documentation Updates: Updated training scripts to include a new --initialize_model_on_cpu argument, enabling models to be initialized on the CPU before moving to accelerators, a common practice with DeepSpeed ZeRO 3. Corresponding documentation (English and Chinese) has been added to guide users on configuring DeepSpeed ZeRO 3 training.
  • New DeepSpeed Configuration Files: Introduced new accelerate_config_zero3.yaml files for various model training examples (Flux, Flux2, Qwen-Image, WanVideo, Z-Image), pre-configured for DeepSpeed ZeRO Stage 3.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • diffsynth/core/loader/model.py
    • Imported contextmanager, is_deepspeed_zero3_enabled, and ContextManagers for DeepSpeed integration.
    • Modified load_model to use ContextManagers(get_init_context(...)) for model initialization, replacing the generic skip_model_initialization().
    • Added conditional logic within load_model to use _load_state_dict_into_zero3_model for DeepSpeed ZeRO 3 enabled environments when loading model state dictionaries.
    • Introduced a new function get_init_context to dynamically provide model initialization contexts based on DeepSpeed ZeRO 3 enablement.
  • diffsynth/diffusion/logger.py
    • Moved the accelerator.get_state_dict(model) call outside the if accelerator.is_main_process block in on_epoch_end and save_model methods to ensure state dict retrieval is handled correctly across all ranks.
  • diffsynth/diffusion/runner.py
    • Added model.to(device=accelerator.device) before accelerator.prepare in launch_training_task and launch_data_process_task for explicit device placement.
  • diffsynth/models/wan_video_dit.py
    • Imported gradient_checkpoint_forward.
    • Replaced custom create_custom_forward function and conditional torch.utils.checkpoint.checkpoint calls with the unified gradient_checkpoint_forward utility for cleaner gradient checkpointing.
  • diffsynth/models/wan_video_dit_s2v.py
    • Imported gradient_checkpoint_forward.
    • Replaced custom create_custom_forward function and conditional torch.utils.checkpoint.checkpoint calls with gradient_checkpoint_forward.
  • diffsynth/models/wan_video_vace.py
    • Imported gradient_checkpoint_forward.
    • Replaced custom create_custom_forward function and conditional torch.utils.checkpoint.checkpoint calls with gradient_checkpoint_forward.
  • diffsynth/models/wan_video_vae.py
    • Modified forward methods in several VAE-related classes (ResidualBlock, Up_ResidualBlock, Encoder3d_Block, Encoder3d, Encoder3d_38, Decoder3d_Block, Decoder3d, Decoder3d_38, AutoencoderKL) to return x, feat_cache, feat_idx instead of just x.
    • Updated calls to these modified forward methods within AutoencoderKL's encode and decode to correctly unpack the returned feat_cache and feat_idx.
  • diffsynth/pipelines/flux2_image.py
    • Removed the with torch.inference_mode(): context manager from around the text_encoder forward pass in get_qwen3_prompt_embeds.
  • diffsynth/pipelines/wan_video.py
    • Removed custom create_custom_forward function.
    • Replaced conditional torch.utils.checkpoint.checkpoint calls with gradient_checkpoint_forward in model_fn_wan_video and its main loop.
  • diffsynth/utils/xfuser/xdit_context_parallel.py
    • Imported gradient_checkpoint_forward.
    • Removed custom create_custom_forward function.
    • Replaced conditional torch.utils.checkpoint.checkpoint calls with gradient_checkpoint_forward.
  • docs/en/Model_Details/Qwen-Image.md
    • Added a new section detailing DeepSpeed ZeRO Stage 3 training for Qwen-Image models, including required --config_file and --initialize_model_on_cpu parameters.
  • docs/en/Model_Details/Wan.md
    • Added a new section detailing DeepSpeed ZeRO Stage 3 training for Wan series models, including required --config_file and --initialize_model_on_cpu parameters.
  • docs/en/Pipeline_Usage/GPU_support.md
    • Updated the GPU support documentation to include Qwen-Image series models as requiring the --initialize_model_on_cpu parameter.
  • docs/zh/Model_Details/Qwen-Image.md
    • Added a Chinese section on DeepSpeed ZeRO 3 training for Qwen-Image models.
  • docs/zh/Model_Details/Wan.md
    • Added a Chinese section on DeepSpeed ZeRO 3 training for Wan series models.
  • docs/zh/Pipeline_Usage/GPU_support.md
    • Updated the Chinese GPU support documentation to include Qwen-Image系列 models as requiring the --initialize_model_on_cpu parameter.
  • examples/flux/model_training/full/accelerate_config_zero3.yaml
    • Added a new DeepSpeed configuration file with zero_stage: 3, zero3_init_flag: true, zero3_save_16bit_model: true, and distributed_type: DEEPSPEED.
  • examples/flux2/model_training/full/accelerate_config_zero3.yaml
    • Added a new DeepSpeed configuration file with zero_stage: 3, zero3_init_flag: true, zero3_save_16bit_model: true, and distributed_type: DEEPSPEED.
  • examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh
    • Added a new NPU training script that utilizes accelerate_config_zero3.yaml and the --initialize_model_on_cpu flag for DeepSpeed training.
  • examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh
    • Added a new NPU training script for FLUX.2-klein-9B models, including PYTORCH_NPU_ALLOC_CONF and CPU_AFFINITY_CONF environment variables.
  • examples/flux2/model_training/train.py
    • Added --initialize_model_on_cpu as a command-line argument.
    • Modified model initialization to conditionally set the device to 'cpu' if --initialize_model_on_cpu is enabled.
  • examples/qwen_image/model_training/full/accelerate_config_zero3.yaml
    • Added a new DeepSpeed configuration file with zero_stage: 3, zero3_init_flag: true, zero3_save_16bit_model: true, and distributed_type: DEEPSPEED.
  • examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh
    • Added a new NPU training script that utilizes accelerate_config_zero3.yaml and the --initialize_model_on_cpu flag for DeepSpeed training.
  • examples/qwen_image/model_training/train.py
    • Added --initialize_model_on_cpu as a command-line argument.
    • Modified model initialization to conditionally set the device to 'cpu' if --initialize_model_on_cpu is enabled.
  • examples/wanvideo/model_training/full/accelerate_config_zero3.yaml
    • Added a new DeepSpeed configuration file with zero_stage: 3, zero3_init_flag: true, zero3_save_16bit_model: true, and distributed_type: DEEPSPEED.
  • examples/z_image/model_training/full/accelerate_config_zero3.yaml
    • Added a new DeepSpeed configuration file with zero_stage: 3, zero3_init_flag: true, zero3_save_16bit_model: true, and distributed_type: DEEPSPEED.
Activity
  • The pull request was created by Artiprocher to introduce DeepSpeed ZeRO 3 support.
  • Changes span across core loading mechanisms, diffusion training utilities, model architectures, pipelines, and documentation.
  • New configuration files and NPU training scripts have been added to facilitate the use of DeepSpeed ZeRO 3.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@Artiprocher Artiprocher merged commit 1b47e1d into main Feb 6, 2026
@Artiprocher Artiprocher deleted the zero3-fix branch February 6, 2026 08:33
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for DeepSpeed ZeRO Stage 3, which is a significant enhancement for training large models. The changes for this are well-implemented across model loading, training execution, and configuration. I've also noticed a substantial refactoring of gradient checkpointing logic into a centralized utility function, which greatly improves code clarity and maintainability. Additionally, there are several documentation updates and new example scripts that will be very helpful for users. I found a couple of potential issues, one related to an unused function parameter and another that could lead to a runtime error due to inconsistent return values in a function. Overall, this is a great contribution.

@@ -511,7 +511,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The forward method of Up_ResidualBlock has inconsistent return values. The else branch was updated to return three values (x_main, feat_cache, feat_idx), but this if branch still returns only one value. This will likely cause an UnpackingError at the call site in Decoder3d_38.forward (line 916), which expects three return values.

Suggested change
return x_main + x_shortcut
return x_main + x_shortcut, feat_cache, feat_idx



def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The state_dict parameter is added to the function signature but it appears to be unused. It gets unconditionally overwritten later in the function (e.g., on lines 39, 41, 46, 48) without its initial value ever being checked. If the intention is to allow passing a pre-loaded state dictionary, the logic should be updated to use this parameter when it's provided. Otherwise, it should be removed to avoid confusion.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants