
Commit 7eba415

init flux2 dit on meta device (#233)

1 parent b6fea88

2 files changed: +4 −2 lines

diffsynth_engine/models/flux2/flux2_dit.py (2 additions, 1 deletion)

```diff
@@ -1058,7 +1058,8 @@ def from_state_dict(
         dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> "Flux2DiT":
-        model = cls(device="meta", dtype=dtype, **kwargs)
+        with torch.device("meta"):
+            model = cls(device="meta", dtype=dtype, **kwargs)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
```
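For context, this change constructs the model under PyTorch's meta device, so parameter tensors carry only shape and dtype metadata with no backing storage, and the throwaway random initialization never allocates real memory; `load_state_dict(..., assign=True)` then swaps the checkpoint tensors in directly before the final `.to(device)`. A minimal self-contained sketch of this pattern, using a hypothetical `TinyModel` in place of `Flux2DiT`:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for Flux2DiT."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)


# Construct under the meta device: parameters get shape/dtype metadata
# but no storage, so initialization allocates no real memory.
with torch.device("meta"):
    model = TinyModel()
assert model.proj.weight.is_meta

# Mirrors the order in the commit: freeze before loading weights.
model.requires_grad_(False)

# assign=True (PyTorch >= 2.1) swaps the checkpoint tensors into the
# module rather than copying into the storage-less meta parameters.
state_dict = {
    "proj.weight": torch.randn(4096, 4096),
    "proj.bias": torch.randn(4096),
}
model.load_state_dict(state_dict, assign=True)

# Only now does the model occupy memory on the target device.
model.to(device="cuda" if torch.cuda.is_available() else "cpu")
```

A plausible reason for adding the `with torch.device("meta")` context on top of the existing `device="meta"` argument is that the context manager also catches any submodule that allocates tensors without honoring the `device=` kwarg.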

diffsynth_engine/pipelines/flux2_klein_image.py (2 additions, 1 deletion)

```diff
@@ -202,7 +202,8 @@ def _from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipeli
         else:
             with open(FLUX2_TEXT_ENCODER_8B_CONF_PATH, "r", encoding="utf-8") as f:
                 qwen3_config = Qwen3Config(**json.load(f))
-        state_dicts.encoder.pop("lm_head.weight")
+        if "lm_head.weight" in state_dicts.encoder:
+            state_dicts.encoder.pop("lm_head.weight")
         dit_config = {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
         text_encoder = Qwen3Model.from_state_dict(
             state_dicts.encoder, config=qwen3_config, device=init_device, dtype=config.encoder_dtype
```
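The second change makes the head-weight removal tolerant of checkpoints that omit it: the pipeline presumably uses Qwen3 purely as a text encoder, so `lm_head.weight` is discarded when present, but checkpoints with tied input/output embeddings may not contain that key at all, in which case the old unconditional `pop` would raise `KeyError`. A minimal illustration, with hypothetical state-dict keys:

```python
import torch

# Hypothetical encoder checkpoint with tied embeddings: no "lm_head.weight".
encoder_sd = {"model.embed_tokens.weight": torch.zeros(8, 4)}

# Old behavior: an unconditional pop fails on such checkpoints.
# encoder_sd.pop("lm_head.weight")  # KeyError: 'lm_head.weight'

# New behavior: drop the unused head weight only if it exists.
if "lm_head.weight" in encoder_sd:
    encoder_sd.pop("lm_head.weight")
```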
