Skip to content

Commit a3b1585

Browse files
authored
fix odtsr rope (#230)
* fix odtsr rope * fix syntax error
1 parent 7a79f14 commit a3b1585

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

‎diffsynth_engine/tools/qwen_image_upscaler_tool.py‎

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from diffsynth_engine.pipelines.qwen_image import QwenImagePipeline
1313
from diffsynth_engine.models.qwen_image import QwenImageVAE
1414
from diffsynth_engine.models.basic.lora import LoRALinear
15-
from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageTransformerBlock
15+
from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageTransformerBlock, QwenEmbedRope
1616
from diffsynth_engine.utils import logging
1717
from diffsynth_engine.utils.loader import load_file
1818
from diffsynth_engine.utils.download import fetch_model
@@ -32,6 +32,7 @@ def odtsr_forward():
3232
"""
3333
original_lora_forward = LoRALinear.forward
3434
original_modulate = QwenImageTransformerBlock._modulate
35+
original_rope_forward = QwenEmbedRope.forward
3536

3637
def lora_batch_cfg_forward(self, x):
3738
y = nn.Linear.forward(self, x)
@@ -50,6 +51,49 @@ def lora_batch_cfg_forward(self, x):
5051
y[:, L:] += lora(x2)
5152
return y
5253

54+
def optimized_rope_forward(self, video_fhw, txt_length, device):
55+
if self.pos_freqs.device != device:
56+
self.pos_freqs = self.pos_freqs.to(device)
57+
self.neg_freqs = self.neg_freqs.to(device)
58+
59+
vid_freqs = []
60+
max_vid_index = 0
61+
idx = 0
62+
for fhw in video_fhw:
63+
frame, height, width = fhw
64+
rope_key = f"{idx}_{height}_{width}"
65+
66+
if rope_key not in self.rope_cache:
67+
seq_lens = frame * height * width
68+
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
69+
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
70+
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
71+
if self.scale_rope:
72+
freqs_height = torch.cat(
73+
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
74+
)
75+
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
76+
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
77+
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
78+
79+
else:
80+
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
81+
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
82+
83+
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
84+
self.rope_cache[rope_key] = freqs.clone().contiguous()
85+
vid_freqs.append(self.rope_cache[rope_key])
86+
if self.scale_rope:
87+
max_vid_index = max(height // 2, width // 2, max_vid_index)
88+
else:
89+
max_vid_index = max(height, width, max_vid_index)
90+
91+
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
92+
vid_freqs = torch.cat(vid_freqs, dim=0)
93+
94+
return vid_freqs, txt_freqs
95+
96+
5397
def optimized_modulate(self, x, mod_params, index=None):
5498
if mod_params.ndim == 2:
5599
shift, scale, gate = mod_params.chunk(3, dim=-1)
@@ -72,12 +116,14 @@ def optimized_modulate(self, x, mod_params, index=None):
72116

73117
LoRALinear.forward = lora_batch_cfg_forward
74118
QwenImageTransformerBlock._modulate = optimized_modulate
119+
QwenEmbedRope.forward = optimized_rope_forward
75120

76121
try:
77122
yield
78123
finally:
79124
LoRALinear.forward = original_lora_forward
80125
QwenImageTransformerBlock._modulate = original_modulate
126+
QwenEmbedRope.forward = original_rope_forward
81127

82128

83129
class QwenImageUpscalerTool:

0 commit comments

Comments
 (0)