diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py
index ffc77bb24fdb..c04130192947 100644
--- a/src/diffusers/modular_pipelines/flux/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
+from ...models import AutoencoderKL
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
 from ...utils.torch_utils import randn_tensor
@@ -103,6 +104,62 @@ def calculate_shift(
     return mu
 
 
+# Adapted from the original implementation.
+def prepare_latents_img2img(
+    vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
+):
+    if isinstance(generator, list) and len(generator) != batch_size:
+        raise ValueError(
+            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+        )
+
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    latent_channels = vae.config.latent_channels
+
+    # VAE applies 8x compression on images but we must also account for packing which requires
+    # latent height and width to be divisible by 2.
+    height = 2 * (int(height) // (vae_scale_factor * 2))
+    width = 2 * (int(width) // (vae_scale_factor * 2))
+    shape = (batch_size, num_channels_latents, height, width)
+    latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+    image = image.to(device=device, dtype=dtype)
+    if image.shape[1] != latent_channels:
+        image_latents = _encode_vae_image(vae=vae, image=image, generator=generator)
+    else:
+        image_latents = image
+    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+        # expand init_latents for batch_size
+        additional_image_per_prompt = batch_size // image_latents.shape[0]
+        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+        raise ValueError(
+            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = scheduler.scale_noise(image_latents, timestep, noise) + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) @@ -125,6 +182,55 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) +# Cannot use "# Copied from" because it introduces weird indentation errors. +def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + height, + width, + vae_scale_factor, + num_inference_steps, + guidance_scale, + sigmas, + device, +): + image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) + if transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + else: + guidance = None + + return timesteps, num_inference_steps, sigmas, guidance + + class FluxInputStep(PipelineBlock): model_name = "flux" @@ -234,18 +340,20 @@ def inputs(self) -> List[InputParam]: InputParam("timesteps"), InputParam("sigmas"), InputParam("guidance_scale", default=3.5), - InputParam("latents", type_hint=torch.Tensor), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), ] @property def intermediate_inputs(self) -> List[str]: return [ InputParam( - "latents", + "batch_size", required=True, - type_hint=torch.Tensor, - description="The 
initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ) + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), ] @property @@ -264,34 +372,127 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device - scheduler = components.scheduler - latents = block_state.latents - image_seq_len = latents.shape[1] + scheduler = components.scheduler + transformer = components.transformer - num_inference_steps = block_state.num_inference_steps - sigmas = block_state.sigmas - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: - sigmas = None + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps block_state.sigmas = sigmas - mu = calculate_shift( - image_seq_len, - scheduler.config.get("base_image_seq_len", 256), - scheduler.config.get("max_image_seq_len", 4096), - scheduler.config.get("base_shift", 0.5), - scheduler.config.get("max_shift", 1.15), + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state + + +class FluxImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("strength", default=0.6), + InputParam("guidance_scale", default=3.5), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. 
Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam( + "latent_timestep", + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image generation", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler + def get_timesteps(scheduler, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + scheduler = components.scheduler + transformer = components.transformer + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, ) - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu + timesteps, num_inference_steps = self.get_timesteps( + scheduler, num_inference_steps, block_state.strength, block_state.device ) - if components.transformer.config.guidance_embeds: - guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + block_state.sigmas = sigmas block_state.guidance = guidance + block_state.latent_timestep = timesteps[:1].repeat(batch_size) + self.set_block_state(state, block_state) return components, state @@ -305,7 +506,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "Prepare latents step that prepares the latents for the text-to-video generation process" + return "Prepare latents step that prepares the latents for the text-to-image generation process" @property def inputs(self) -> List[InputParam]: @@ -402,10 +603,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.num_channels_latents = components.num_channels_latents self.check_inputs(components, block_state) - + batch_size = block_state.batch_size * block_state.num_images_per_prompt block_state.latents, block_state.latent_image_ids = self.prepare_latents( components, - block_state.batch_size 
* block_state.num_images_per_prompt, + batch_size, block_state.num_channels_latents, block_state.height, block_state.width, @@ -418,3 +619,95 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip self.set_block_state(state, block_state) return components, state + + +class FluxImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that prepares the latents for the image-to-image generation process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "latent_image_ids", + type_hint=torch.Tensor, + description="IDs computed from the image sequence needed for RoPE", + ), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? 
+ block_state.num_channels_latents = components.num_channels_latents + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + # TODO: implement `check_inputs` + batch_size = block_state.batch_size * block_state.num_images_per_prompt + if block_state.latents is None: + block_state.latents, block_state.latent_image_ids = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index c4619c17fb0e..79b825a0e739 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -226,5 +226,5 @@ def description(self) -> str: "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `FluxLoopDenoiser`\n" " - `FluxLoopAfterDenoiser`\n" - "This block supports text2image tasks." + "This block supports both text2image and img2img tasks." ) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 9bf2f54eece3..73ccd040afaf 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -19,7 +19,10 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam @@ -50,6 +53,110 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxVaeEncoderStep(PipelineBlock): + model_name = "flux" + + @property + def description(self) -> str: + return "Vae Encoder step that encode the input image into a latent representation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("image", required=True), InputParam("height"), InputParam("width")] + + @property + def 
intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam( + "preprocess_kwargs", + type_hint=Optional[dict], + description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation", + ) + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae + def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess( + block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs + ) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." + ) + + block_state.image_latents = self._encode_vae_image( + components.vae, image=block_state.image, generator=block_state.generator + ) + + self.set_block_state(state, block_state) + + return components, state + + class FluxTextEncoderStep(PipelineBlock): model_name = "flux" @@ -297,7 +404,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip prompt_embeds=None, pooled_prompt_embeds=None, device=block_state.device, - num_images_per_prompt=1, # hardcoded for now. + num_images_per_prompt=1, # TODO: hardcoded for now. 
             lora_scale=block_state.text_encoder_lora_scale,
         )
 
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
index b17067303785..04b439f026a4 100644
--- a/src/diffusers/modular_pipelines/flux/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -15,16 +15,38 @@
 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
 from ..modular_pipeline_utils import InsertableDict
-from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
+from .before_denoise import (
+    FluxImg2ImgPrepareLatentsStep,
+    FluxImg2ImgSetTimestepsStep,
+    FluxInputStep,
+    FluxPrepareLatentsStep,
+    FluxSetTimestepsStep,
+)
 from .decoders import FluxDecodeStep
 from .denoise import FluxDenoiseStep
-from .encoders import FluxTextEncoderStep
+from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# before_denoise: text2vid
+# vae encoder (run before before_denoise)
+class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [FluxVaeEncoderStep]
+    block_names = ["img2img"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block that works for img2img tasks.\n"
+            + " - `FluxVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+            + " - if `image` is not provided, this step will be skipped."
+        )
+
+
+# before_denoise: text2img
 class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
         FluxInputStep,
@@ -44,11 +66,27 @@ def description(self):
         )
 
 
-# before_denoise: all task (text2vid,)
+# before_denoise: img2img
+class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
+    block_names = ["input", "set_timesteps", "prepare_latents"]
+
+    @property
+    def description(self):
+        return (
+            "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n"
+            + "This is a sequential pipeline block:\n"
+            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
+            + " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+            + " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+        )
+
+
+# before_denoise: all tasks (text2img, img2img)
 class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
-    block_classes = [FluxBeforeDenoiseStep]
-    block_names = ["text2image"]
-    block_trigger_inputs = [None]
+    block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
+    block_names = ["text2image", "img2img"]
+    block_trigger_inputs = [None, "image_latents"]
 
     @property
     def description(self):
@@ -56,6 +94,7 @@ def description(self):
             "Before denoise step that prepare the inputs for the denoise step.\n"
             + "This is an auto pipeline block that works for text2image.\n"
             + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+            + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
         )
 
 
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
     @property
     def description(self) -> str:
         return (
             "Denoise step that iteratively denoise the latents. "
-            "This is a auto pipeline block that works for text2image tasks."
-            " - `FluxDenoiseStep` (denoise) for text2image tasks."
+ "This is a auto pipeline block that works for text2image and img2img tasks." + " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks." ) @@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks): @property def description(self): - return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`" + return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`" # text2image class FluxAutoBlocks(SequentialPipelineBlocks): - block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep] - block_names = ["text_encoder", "before_denoise", "denoise", "decoder"] + block_classes = [ + FluxTextEncoderStep, + FluxAutoVaeEncoderStep, + FluxAutoBeforeDenoiseStep, + FluxAutoDenoiseStep, + FluxAutoDecodeStep, + ] + block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"] @property def description(self): return ( - "Auto Modular pipeline for text-to-image using Flux.\n" - + "- for text-to-image generation, all you need to provide is `prompt`" + "Auto Modular pipeline for text-to-image and image-to-image using Flux.\n" + + "- for text-to-image generation, all you need to provide is `prompt`\n" + + "- for image-to-image generation, you need to provide either `image` or `image_latents`" ) @@ -102,19 +148,29 @@ def description(self): [ ("text_encoder", FluxTextEncoderStep), ("input", FluxInputStep), - ("prepare_latents", FluxPrepareLatentsStep), - # Setting it after preparation of latents because we rely on `latents` - # to calculate `img_seq_len` for `shift`. ("set_timesteps", FluxSetTimestepsStep), + ("prepare_latents", FluxPrepareLatentsStep), ("denoise", FluxDenoiseStep), ("decode", FluxDecodeStep), ] ) +IMAGE2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep), + ("image_encoder", FluxVaeEncoderStep), + ("input", FluxInputStep), + ("set_timesteps", FluxImg2ImgSetTimestepsStep), + ("prepare_latents", FluxImg2ImgPrepareLatentsStep), + ("denoise", FluxDenoiseStep), + ("decode", FluxDecodeStep), + ] +) AUTO_BLOCKS = InsertableDict( [ ("text_encoder", FluxTextEncoderStep), + ("image_encoder", FluxAutoVaeEncoderStep), ("before_denoise", FluxAutoBeforeDenoiseStep), ("denoise", FluxAutoDenoiseStep), ("decode", FluxAutoDecodeStep), @@ -122,4 +178,4 @@ def description(self): ) -ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS} +ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS} diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py index 3cd5df0c70ee..e97445d411e4 100644 --- a/src/diffusers/modular_pipelines/flux/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py @@ -13,7 +13,7 @@ # limitations under the License. -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline @@ -21,7 +21,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin): +class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin): """ A ModularPipeline for Flux.