@@ -38,6 +38,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
3838 Flux2Unit_Qwen3PromptEmbedder (),
3939 Flux2Unit_NoiseInitializer (),
4040 Flux2Unit_InputImageEmbedder (),
41+ Flux2Unit_EditImageEmbedder (),
4142 Flux2Unit_ImageIDs (),
4243 ]
4344 self .model_fn = model_fn_flux2
@@ -80,6 +81,9 @@ def __call__(
8081 # Image
8182 input_image : Image .Image = None ,
8283 denoising_strength : float = 1.0 ,
84+ # Edit
85+ edit_image : Union [Image .Image , List [Image .Image ]] = None ,
86+ edit_image_auto_resize : bool = True ,
8387 # Shape
8488 height : int = 1024 ,
8589 width : int = 1024 ,
@@ -103,6 +107,7 @@ def __call__(
103107 inputs_shared = {
104108 "cfg_scale" : cfg_scale , "embedded_guidance" : embedded_guidance ,
105109 "input_image" : input_image , "denoising_strength" : denoising_strength ,
110+ "edit_image" : edit_image , "edit_image_auto_resize" : edit_image_auto_resize ,
106111 "height" : height , "width" : width ,
107112 "seed" : seed , "rand_device" : rand_device ,
108113 "num_inference_steps" : num_inference_steps ,
@@ -457,6 +462,64 @@ def process(self, pipe: Flux2ImagePipeline, input_image, noise):
457462 return {"latents" : latents , "input_latents" : input_latents }
458463
459464
465+ class Flux2Unit_EditImageEmbedder (PipelineUnit ):
466+ def __init__ (self ):
467+ super ().__init__ (
468+ input_params = ("edit_image" , "edit_image_auto_resize" ),
469+ output_params = ("edit_latents" , "edit_image_ids" ),
470+ onload_model_names = ("vae" ,)
471+ )
472+
473+ def calculate_dimensions (self , target_area , ratio ):
474+ import math
475+ width = math .sqrt (target_area * ratio )
476+ height = width / ratio
477+ width = round (width / 32 ) * 32
478+ height = round (height / 32 ) * 32
479+ return width , height
480+
481+ def edit_image_auto_resize (self , edit_image ):
482+ calculated_width , calculated_height = self .calculate_dimensions (1024 * 1024 , edit_image .size [0 ] / edit_image .size [1 ])
483+ return edit_image .resize ((calculated_width , calculated_height ))
484+
485+ def process_image_ids (self , image_latents , scale = 10 ):
486+ t_coords = [scale + scale * t for t in torch .arange (0 , len (image_latents ))]
487+ t_coords = [t .view (- 1 ) for t in t_coords ]
488+
489+ image_latent_ids = []
490+ for x , t in zip (image_latents , t_coords ):
491+ x = x .squeeze (0 )
492+ _ , height , width = x .shape
493+
494+ x_ids = torch .cartesian_prod (t , torch .arange (height ), torch .arange (width ), torch .arange (1 ))
495+ image_latent_ids .append (x_ids )
496+
497+ image_latent_ids = torch .cat (image_latent_ids , dim = 0 )
498+ image_latent_ids = image_latent_ids .unsqueeze (0 )
499+
500+ return image_latent_ids
501+
502+ def process (self , pipe : Flux2ImagePipeline , edit_image , edit_image_auto_resize ):
503+ if edit_image is None :
504+ return {}
505+ pipe .load_models_to_device (self .onload_model_names )
506+ if isinstance (edit_image , Image .Image ):
507+ edit_image = [edit_image ]
508+ resized_edit_image , edit_latents = [], []
509+ for image in edit_image :
510+ # Preprocess
511+ if edit_image_auto_resize is None or edit_image_auto_resize :
512+ image = self .edit_image_auto_resize (image )
513+ resized_edit_image .append (image )
514+ # Encode
515+ image = pipe .preprocess_image (image )
516+ latents = pipe .vae .encode (image )
517+ edit_latents .append (latents )
518+ edit_image_ids = self .process_image_ids (edit_latents ).to (pipe .device )
519+ edit_latents = torch .concat ([rearrange (latents , "B C H W -> B (H W) C" ) for latents in edit_latents ], dim = 1 )
520+ return {"edit_latents" : edit_latents , "edit_image_ids" : edit_image_ids }
521+
522+
460523class Flux2Unit_ImageIDs (PipelineUnit ):
461524 def __init__ (self ):
462525 super ().__init__ (
@@ -491,10 +554,17 @@ def model_fn_flux2(
491554 prompt_embeds = None ,
492555 text_ids = None ,
493556 image_ids = None ,
557+ edit_latents = None ,
558+ edit_image_ids = None ,
494559 use_gradient_checkpointing = False ,
495560 use_gradient_checkpointing_offload = False ,
496561 ** kwargs ,
497562):
563+ image_seq_len = latents .shape [1 ]
564+ if edit_latents is not None :
565+ image_seq_len = latents .shape [1 ]
566+ latents = torch .concat ([latents , edit_latents ], dim = 1 )
567+ image_ids = torch .concat ([image_ids , edit_image_ids ], dim = 1 )
498568 embedded_guidance = torch .tensor ([embedded_guidance ], device = latents .device )
499569 model_output = dit (
500570 hidden_states = latents ,
@@ -506,4 +576,5 @@ def model_fn_flux2(
506576 use_gradient_checkpointing = use_gradient_checkpointing ,
507577 use_gradient_checkpointing_offload = use_gradient_checkpointing_offload ,
508578 )
579+ model_output = model_output [:, :image_seq_len ]
509580 return model_output
0 commit comments