Skip to content

Commit 3743b13

Browse files
authored
Merge pull request #1219 from modelscope/klein-edit
support klein edit
2 parents 3e4b47e + a835df9 commit 3743b13

18 files changed

+273
-2
lines changed

‎diffsynth/pipelines/flux2_image.py‎

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
3838
Flux2Unit_Qwen3PromptEmbedder(),
3939
Flux2Unit_NoiseInitializer(),
4040
Flux2Unit_InputImageEmbedder(),
41+
Flux2Unit_EditImageEmbedder(),
4142
Flux2Unit_ImageIDs(),
4243
]
4344
self.model_fn = model_fn_flux2
@@ -80,6 +81,9 @@ def __call__(
8081
# Image
8182
input_image: Image.Image = None,
8283
denoising_strength: float = 1.0,
84+
# Edit
85+
edit_image: Union[Image.Image, List[Image.Image]] = None,
86+
edit_image_auto_resize: bool = True,
8387
# Shape
8488
height: int = 1024,
8589
width: int = 1024,
@@ -103,6 +107,7 @@ def __call__(
103107
inputs_shared = {
104108
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
105109
"input_image": input_image, "denoising_strength": denoising_strength,
110+
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
106111
"height": height, "width": width,
107112
"seed": seed, "rand_device": rand_device,
108113
"num_inference_steps": num_inference_steps,
@@ -457,6 +462,64 @@ def process(self, pipe: Flux2ImagePipeline, input_image, noise):
457462
return {"latents": latents, "input_latents": input_latents}
458463

459464

465+
class Flux2Unit_EditImageEmbedder(PipelineUnit):
    """Pipeline unit that encodes reference ("edit") images with the VAE.

    Reads ``edit_image`` and ``edit_image_auto_resize`` from the pipeline
    inputs and produces ``edit_latents`` (one flattened token sequence for
    all edit images) together with ``edit_image_ids`` (a 4-component
    coordinate per token).
    """

    def __init__(self):
        super().__init__(
            input_params=("edit_image", "edit_image_auto_resize"),
            output_params=("edit_latents", "edit_image_ids"),
            onload_model_names=("vae",)
        )

    def calculate_dimensions(self, target_area, ratio):
        """Return ``(width, height)`` with roughly ``target_area`` pixels.

        Args:
            target_area: Desired pixel count (width * height).
            ratio: Aspect ratio expressed as width / height.

        Returns:
            Two ints, each a multiple of 32 and at least 32.
        """
        import math
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        # Snap to the 32-pixel grid. The lower bound keeps an extreme aspect
        # ratio from rounding one side down to 0, which would make the
        # subsequent resize fail.
        width = max(32, round(width / 32) * 32)
        height = max(32, round(height / 32) * 32)
        return width, height

    def edit_image_auto_resize(self, edit_image, target_area=1024 * 1024):
        """Resize ``edit_image`` to about ``target_area`` pixels.

        The aspect ratio is taken from ``edit_image.size`` and the new size
        comes from :meth:`calculate_dimensions`.
        """
        source_width, source_height = edit_image.size[0], edit_image.size[1]
        new_width, new_height = self.calculate_dimensions(target_area, source_width / source_height)
        return edit_image.resize((new_width, new_height))

    def process_image_ids(self, image_latents, scale=10):
        """Build per-token coordinates for a list of latent tensors.

        For the i-th latent, every token gets the 4-tuple
        ``(scale + scale * i, row, col, 0)``. Rows vary slowest and columns
        fastest, which is the same order as the ``(H W)`` flattening used in
        :meth:`process`.

        Args:
            image_latents: Non-empty list of latents; after ``squeeze(0)``
                each must be 3-D with height and width as its last two axes.
            scale: Spacing of the first coordinate between consecutive images.

        Returns:
            Tensor of shape ``(1, total_tokens, 4)``.
        """
        ids_per_image = []
        for index, latent in zip(torch.arange(0, len(image_latents)), image_latents):
            # One-element tensor carrying this image's first coordinate.
            t_coord = (scale + scale * index).view(-1)
            _, height, width = latent.squeeze(0).shape
            ids_per_image.append(
                torch.cartesian_prod(t_coord, torch.arange(height), torch.arange(width), torch.arange(1))
            )
        return torch.cat(ids_per_image, dim=0).unsqueeze(0)

    def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
        """Encode the edit image(s) into latent tokens and coordinates.

        Args:
            pipe: Pipeline providing ``load_models_to_device``,
                ``preprocess_image``, ``vae`` and ``device``.
            edit_image: A single image, a list of images, or ``None``.
            edit_image_auto_resize: When ``None`` or truthy, each image is
                resized with :meth:`edit_image_auto_resize` before encoding.

        Returns:
            ``{}`` when there is nothing to encode; otherwise a dict with
            ``"edit_latents"`` and ``"edit_image_ids"``.
        """
        if edit_image is None:
            return {}
        if isinstance(edit_image, Image.Image):
            edit_image = [edit_image]
        # An empty list means "no edit images"; without this guard the
        # concatenations below would raise on an empty sequence.
        if len(edit_image) == 0:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        resize_enabled = edit_image_auto_resize is None or edit_image_auto_resize
        edit_latents = []
        for image in edit_image:
            if resize_enabled:
                image = self.edit_image_auto_resize(image)
            image = pipe.preprocess_image(image)
            edit_latents.append(pipe.vae.encode(image))
        # Coordinates are computed from the un-flattened latents, so build
        # them before the latents are rearranged into token sequences.
        edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
        edit_latents = torch.concat(
            [rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1
        )
        return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
521+
522+
460523
class Flux2Unit_ImageIDs(PipelineUnit):
461524
def __init__(self):
462525
super().__init__(
@@ -491,10 +554,17 @@ def model_fn_flux2(
491554
prompt_embeds=None,
492555
text_ids=None,
493556
image_ids=None,
557+
edit_latents=None,
558+
edit_image_ids=None,
494559
use_gradient_checkpointing=False,
495560
use_gradient_checkpointing_offload=False,
496561
**kwargs,
497562
):
563+
image_seq_len = latents.shape[1]
564+
if edit_latents is not None:
565+
image_seq_len = latents.shape[1]
566+
latents = torch.concat([latents, edit_latents], dim=1)
567+
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
498568
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
499569
model_output = dit(
500570
hidden_states=latents,
@@ -506,4 +576,5 @@ def model_fn_flux2(
506576
use_gradient_checkpointing=use_gradient_checkpointing,
507577
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
508578
)
579+
model_output = model_output[:, :image_seq_len]
509580
return model_output

‎examples/flux2/model_inference/FLUX.2-klein-4B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@
1515
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
1616
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
1717
image.save("image_FLUX.2-klein-4B.jpg")
18+
19+
prompt = "change the color of the clothes to red"
20+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
21+
image.save("image_edit_FLUX.2-klein-4B.jpg")

‎examples/flux2/model_inference/FLUX.2-klein-9B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@
1515
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
1616
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
1717
image.save("image_FLUX.2-klein-9B.jpg")
18+
19+
prompt = "change the color of the clothes to red"
20+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
21+
image.save("image_edit_FLUX.2-klein-9B.jpg")

‎examples/flux2/model_inference/FLUX.2-klein-base-4B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@
1515
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
1616
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
1717
image.save("image_FLUX.2-klein-base-4B.jpg")
18+
19+
prompt = "change the color of the clothes to red"
20+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
21+
image.save("image_edit_FLUX.2-klein-base-4B.jpg")

‎examples/flux2/model_inference/FLUX.2-klein-base-9B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@
1515
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
1616
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
1717
image.save("image_FLUX.2-klein-base-9B.jpg")
18+
19+
prompt = "change the color of the clothes to red"
20+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
21+
image.save("image_edit_FLUX.2-klein-base-9B.jpg")

‎examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@
2525
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
2626
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
2727
image.save("image_FLUX.2-klein-4B.jpg")
28+
29+
prompt = "change the color of the clothes to red"
30+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
31+
image.save("image_edit_FLUX.2-klein-4B.jpg")

‎examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@
2525
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
2626
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
2727
image.save("image_FLUX.2-klein-9B.jpg")
28+
29+
prompt = "change the color of the clothes to red"
30+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
31+
image.save("image_edit_FLUX.2-klein-9B.jpg")

‎examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@
2525
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
2626
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
2727
image.save("image_FLUX.2-klein-base-4B.jpg")
28+
29+
prompt = "change the color of the clothes to red"
30+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
31+
image.save("image_edit_FLUX.2-klein-base-4B.jpg")

‎examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@
2525
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
2626
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
2727
image.save("image_FLUX.2-klein-base-9B.jpg")
28+
29+
prompt = "change the color of the clothes to red"
30+
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
31+
image.save("image_edit_FLUX.2-klein-base-9B.jpg")

‎examples/flux2/model_training/full/FLUX.2-klein-4B.sh‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,20 @@ accelerate launch examples/flux2/model_training/train.py \
1111
--output_path "./models/train/FLUX.2-klein-4B_full" \
1212
--trainable_models "dit" \
1313
--use_gradient_checkpointing
14+
15+
# Edit
16+
# accelerate launch examples/flux2/model_training/train.py \
17+
# --dataset_base_path data/example_image_dataset \
18+
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
19+
# --data_file_keys "image,edit_image" \
20+
# --extra_inputs "edit_image" \
21+
# --max_pixels 1048576 \
22+
# --dataset_repeat 50 \
23+
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
24+
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
25+
# --learning_rate 1e-5 \
26+
# --num_epochs 2 \
27+
# --remove_prefix_in_ckpt "pipe.dit." \
28+
# --output_path "./models/train/FLUX.2-klein-4B_full" \
29+
# --trainable_models "dit" \
30+
# --use_gradient_checkpointing

0 commit comments

Comments
 (0)