1212from diffsynth_engine .pipelines .qwen_image import QwenImagePipeline
1313from diffsynth_engine .models .qwen_image import QwenImageVAE
1414from diffsynth_engine .models .basic .lora import LoRALinear
15- from diffsynth_engine .models .qwen_image .qwen_image_dit import QwenImageTransformerBlock
15+ from diffsynth_engine .models .qwen_image .qwen_image_dit import QwenImageTransformerBlock , QwenEmbedRope
1616from diffsynth_engine .utils import logging
1717from diffsynth_engine .utils .loader import load_file
1818from diffsynth_engine .utils .download import fetch_model
@@ -32,6 +32,7 @@ def odtsr_forward():
3232 """
3333 original_lora_forward = LoRALinear .forward
3434 original_modulate = QwenImageTransformerBlock ._modulate
35+ original_rope_forward = QwenEmbedRope .forward
3536
3637 def lora_batch_cfg_forward (self , x ):
3738 y = nn .Linear .forward (self , x )
@@ -50,6 +51,49 @@ def lora_batch_cfg_forward(self, x):
5051 y [:, L :] += lora (x2 )
5152 return y
5253
def optimized_rope_forward(self, video_fhw, txt_length, device):
    """Cached replacement for ``QwenEmbedRope.forward``.

    Builds rotary-position-embedding frequency tables for each
    ``(frame, height, width)`` entry in ``video_fhw`` plus a trailing text
    segment, memoizing the per-image tables in ``self.rope_cache`` so that
    repeated denoising steps with the same spatial sizes skip the tensor
    construction entirely.

    Args:
        video_fhw: sequence of ``(frame, height, width)`` tuples, one per image.
        txt_length: number of text tokens to slice frequencies for.
        device: target device for the frequency tensors.

    Returns:
        ``(vid_freqs, txt_freqs)`` — the concatenated video table of shape
        ``(sum(f*h*w), dim)`` and the text table of shape ``(txt_length, dim)``.
    """
    # Lazily migrate the base frequency tables to the requested device once.
    if self.pos_freqs.device != device:
        self.pos_freqs = self.pos_freqs.to(device)
        self.neg_freqs = self.neg_freqs.to(device)

    vid_freqs = []
    max_vid_index = 0
    # BUG FIX: the original set ``idx = 0`` before the loop and never advanced
    # it, so every image shared frame position 0 and entries with equal
    # height/width collided on one cache key. Enumerating restores the
    # reference semantics (per-image frame offset in both the cache key and
    # the frame slice); identical for the common single-image case.
    for idx, fhw in enumerate(video_fhw):
        frame, height, width = fhw
        rope_key = f"{idx}_{height}_{width}"

        if rope_key not in self.rope_cache:
            seq_len = frame * height * width
            # Split per RoPE axis: each axis gets axes_dim[i] // 2 frequency columns.
            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
            freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
            if self.scale_rope:
                # Centered RoPE: negative frequencies cover the first half of
                # each spatial axis, positive ones the second half.
                freqs_height = torch.cat(
                    [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                )
                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
            else:
                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_len, -1)
            # Clone so the cached tensor owns contiguous storage instead of
            # keeping the expanded views (and their base tensors) alive.
            self.rope_cache[rope_key] = freqs.clone().contiguous()
        vid_freqs.append(self.rope_cache[rope_key])

        # Track the largest spatial index so text positions start after it.
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2, max_vid_index)
        else:
            max_vid_index = max(height, width, max_vid_index)

    txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
    vid_freqs = torch.cat(vid_freqs, dim=0)

    return vid_freqs, txt_freqs
95+
96+
5397 def optimized_modulate (self , x , mod_params , index = None ):
5498 if mod_params .ndim == 2 :
5599 shift , scale , gate = mod_params .chunk (3 , dim = - 1 )
@@ -72,12 +116,14 @@ def optimized_modulate(self, x, mod_params, index=None):
72116
73117 LoRALinear .forward = lora_batch_cfg_forward
74118 QwenImageTransformerBlock ._modulate = optimized_modulate
119+ QwenEmbedRope .forward = optimized_rope_forward
75120
76121 try :
77122 yield
78123 finally :
79124 LoRALinear .forward = original_lora_forward
80125 QwenImageTransformerBlock ._modulate = original_modulate
126+ QwenEmbedRope .forward = original_rope_forward
81127
82128
83129class QwenImageUpscalerTool :
0 commit comments