From 00dfe98aa8d806c85b70ca5f9d894fdc18b7a47c Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Mon, 28 Jul 2025 08:37:34 +0000 Subject: [PATCH 01/18] Add Bria model and pipeline to diffusers - Introduced `BriaTransformer2DModel` and `BriaPipeline` for enhanced image generation capabilities. - Updated import structures across various modules to include the new Bria components. - Added utility functions and output classes specific to the Bria pipeline. - Implemented tests for the Bria pipeline to ensure functionality and output integrity. --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/embeddings.py | 1 - src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_bria.py | 417 +++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/bria/__init__.py | 23 + src/diffusers/pipelines/bria/bria_utils.py | 464 ++++++++++++ src/diffusers/pipelines/bria/pipeline_bria.py | 665 ++++++++++++++++++ .../pipelines/bria/pipeline_output.py | 21 + tests/pipelines/bria/__init__.py | 0 tests/pipelines/bria/test_pipeline_bria.py | 372 ++++++++++ 12 files changed, 1971 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_bria.py create mode 100644 src/diffusers/pipelines/bria/__init__.py create mode 100644 src/diffusers/pipelines/bria/bria_utils.py create mode 100644 src/diffusers/pipelines/bria/pipeline_bria.py create mode 100644 src/diffusers/pipelines/bria/pipeline_output.py create mode 100644 tests/pipelines/bria/__init__.py create mode 100644 tests/pipelines/bria/test_pipeline_bria.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 80c78b8a96d5..9288f286858a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -179,6 +179,7 @@ "AutoencoderOobleck", "AutoencoderTiny", "AutoModel", + "BriaTransformer2DModel", "CacheMixin", "ChromaTransformer2DModel", "CogVideoXTransformer3DModel", @@ -392,6 +393,7 @@ "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaPipeline", "ChromaImg2ImgPipeline", "ChromaPipeline", "CLIPImageProjection", @@ -835,6 +837,7 @@ AutoencoderOobleck, AutoencoderTiny, AutoModel, + BriaTransformer2DModel, CacheMixin, ChromaTransformer2DModel, CogVideoXTransformer3DModel, @@ -1023,6 +1026,7 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaPipeline, ChromaImg2ImgPipeline, ChromaPipeline, CLIPImageProjection, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index cd1df3667a18..f095beb5ced1 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -75,6 +75,7 @@ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] + _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] @@ -155,6 +156,7 @@ from .transformers import ( AllegroTransformer3DModel, AuraFlowTransformer2DModel, + BriaTransformer2DModel, ChromaTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/embeddings.py 
b/src/diffusers/models/embeddings.py index b51f5d7aec25..051a776e49fd 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1220,7 +1220,6 @@ def apply_rotary_emb( x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index dd8813369b5d..e998eefac0f7 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,6 +17,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel + from .transformer_bria import BriaTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_cogview4 import CogView4Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py new file mode 100644 index 000000000000..92b413aa81ac --- /dev/null +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -0,0 +1,417 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from diffusers.pipelines.bria.bria_utils import FluxPosEmbed as EmbedND +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.time_theta = time_theta + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.time_theta, + ) + return t_emb + + +class TimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, time_theta): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta + ) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) + return timesteps_emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + 
ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +class FluxPosEmbed(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +""" +Based on FluxPipeline with several changes: +- no pooled embeddings +- We use zero padding for 
prompts +- No guidance embedding since this is not a distilled version +""" + + +class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = None, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + rope_theta=10000, + time_theta=10000, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + + self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) + + # if pooled_projection_dim: + # self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu") + + if guidance_embeds: + self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + 
guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) + else: + guidance = None + + # temb = ( + # self.time_text_embed(timestep, pooled_projections) + # if guidance is None + # else self.time_text_embed(timestep, guidance, pooled_projections) + # ) + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + + # if pooled_projections: + # temb+=self.pooled_text_embed(pooled_projections) + + if guidance: + temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if len(txt_ids.shape) == 2: + ids = torch.cat((txt_ids, img_ids), dim=0) + else: + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + # hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c8fbdf0c6c29..306b0675b562 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -127,6 +127,7 @@ "AnimateDiffVideoToVideoPipeline", "AnimateDiffVideoToVideoControlNetPipeline", ] + _import_structure["bria"] = ["BriaPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -546,6 +547,7 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline + from .bria import BriaPipeline from .chroma import ChromaImg2ImgPipeline, ChromaPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/bria/__init__.py b/src/diffusers/pipelines/bria/__init__.py new file mode 100644 index 000000000000..88e51b534ab0 --- /dev/null +++ b/src/diffusers/pipelines/bria/__init__.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = { + "pipeline_bria": ["BriaPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_bria import BriaPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/src/diffusers/pipelines/bria/bria_utils.py b/src/diffusers/pipelines/bria/bria_utils.py new file mode 100644 index 000000000000..fbec780333fc --- /dev/null +++ b/src/diffusers/pipelines/bria/bria_utils.py @@ -0,0 +1,464 @@ +import math +import os +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +from transformers import ( + AutoTokenizer, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.optimization import get_scheduler +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_text(caption): + existing_text_list = set() + + if caption[0] == '"' and caption[-1] == '"': + caption = caption[1:-2] + + if caption[0] == "'" and caption[-1] == "'": + caption = caption[1:-2] + + text_list = [] + current_text = "" + text_present = False + for c in caption: + if c == '"' and not text_present: + text_present = True + continue + + if c == '"' and text_present: + if current_text not in existing_text_list: + text_list += [current_text] + existing_text_list.add(current_text) + + text_present = False + current_text = "" + continue + + if text_present: + current_text += c + + return text_list + + +def get_by_t5_prompt_embeds( + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + max_sequence_length: int = 128, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + + if isinstance(prompt, list): + assert len(prompt) == 1 + prompt = prompt[0] + + assert type(prompt) == str + + captions_list = get_text(prompt) + embeddings_list = [] + for inner_prompt in captions_list: + 
text_inputs = tokenizer( + [inner_prompt], + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + embeddings_list += [prompt_embeds[0]] + + # No Text Found + if len(embeddings_list) == 0: + return None + + prompt_embeds = torch.concatenate(embeddings_list, axis=0) + + # Concat zeros to max_sequence + seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + return prompt_embeds + + +def get_t5_prompt_embeds( + tokenizer: T5TokenizerFast, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + + prompt_embeds = prompt_embeds.to(device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +# in order the get the same sigmas as in training and sample from them +def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + sigmas = timesteps / num_train_timesteps + + inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] + new_sigmas = sigmas[inds] + return new_sigmas + + +def is_ng_none(negative_prompt): + return ( + negative_prompt is None + or negative_prompt == "" + or (isinstance(negative_prompt, list) and negative_prompt[0] is None) + or (type(negative_prompt) == list and negative_prompt[0] == "") + ) + + +class CudaTimerContext: + def __init__(self, times_arr): + self.times_arr = times_arr + + def __enter__(self): + self.before_event = torch.cuda.Event(enable_timing=True) + self.after_event = torch.cuda.Event(enable_timing=True) + self.before_event.record() + + def __exit__(self, 
type, value, traceback): + self.after_event.record() + torch.cuda.synchronize() + elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000 + self.times_arr.append(elapsed_time) + + +def get_env_prefix(): + env = os.environ.get("CLOUD_PROVIDER", "AWS").upper() + if env == "AWS": + return "SM_CHANNEL" + elif env == "AZURE": + return "AZUREML_DATAREFERENCE" + + raise Exception(f"Env {env} not supported") + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def initialize_distributed(): + # Initialize the process group for distributed training + dist.init_process_group("nccl") + + # Get the current process's rank (ID) and the total number of processes (world size) + rank = dist.get_rank() + world_size = dist.get_world_size() + + print(f"Initialized distributed training: Rank {rank}/{world_size}") + + +def get_clip_prompt_embeds( + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 77, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + assert max_sequence_length == tokenizer.model_max_length + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Define tokenizers and text encoders + tokenizers = [tokenizer, tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = 
torch.concat(prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, pooled_prompt_embeds + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +class FluxPosEmbed(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +# Not really cosine but with decay +def get_cosine_schedule_with_warmup_and_decay( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + constant_steps=-1, + eps=1e-5, +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + constant_steps (`int`): + The total number of constant lr steps following a warmup + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + if constant_steps <= 0: + constant_steps = num_training_steps - num_warmup_steps + + def lr_lambda(current_step): + # Accelerate sends current_step*num_processes + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step < num_warmup_steps + constant_steps: + return 1 + + # print(f'Inside LR: num_training_steps:{num_training_steps}, current_step:{current_step}, num_warmup_steps: {num_warmup_steps}, constant_steps: {constant_steps}') + return max( + eps, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps - constant_steps)), + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_lr_scheduler(name, optimizer, num_warmup_steps, num_training_steps, constant_steps): + if name != "constant_with_warmup_cosine_decay": + return get_scheduler( + name=name, optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) + + # Usign custom warmup+cnstant+decay scheduler + return get_cosine_schedule_with_warmup_and_decay( + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + constant_steps=constant_steps, + ) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py new file mode 100644 index 000000000000..1b9542f6ddd8 --- /dev/null +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -0,0 +1,665 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + T5EncoderModel, + T5TokenizerFast, +) + +import diffusers +from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin +from diffusers.models.transformers.transformer_bria import BriaTransformer2DModel +from diffusers.pipelines.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none +from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput +from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3Pipeline + + >>> pipe = StableDiffusion3Pipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... 
)
+        >>> pipe.to("cuda")
+        >>> prompt = "A cat holding a sign that says hello world"
+        >>> image = pipe(prompt).images[0]
+        >>> image.save("sd3.png")
+        ```
+"""
+
+T5_PRECISION = torch.float16
+
+"""
+Based on FluxPipeline with several changes:
+- no pooled embeddings
+- We use zero padding for prompts
+- No guidance embedding since this is not a distilled version
+"""
+
+
+class BriaPipeline(FluxPipeline):
+    r"""
+    Args:
+        transformer ([`BriaTransformer2DModel`]):
+            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. Bria uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`T5TokenizerFast`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+    """
+
+    def __init__(
+        self,
+        transformer: BriaTransformer2DModel,
+        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+        vae: AutoencoderKL,
+        text_encoder: T5EncoderModel,
+        tokenizer: T5TokenizerFast,
+    ):
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+        )
+
+        # TODO - why different than official flux (-1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+        )
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.default_sample_size = 64  # due to patchify=> 128,128 => res of 1k,1k
+
+        # T5 is sensitive to precision so we use the precision used for precompute and cast as needed
+        self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
+        for block in self.text_encoder.encoder.block:
+            block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+        if self.vae.config.shift_factor is None:
+            self.vae.config.shift_factor = 0
+            self.vae.to(dtype=torch.float32)
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        max_sequence_length: int = 128,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = get_t5_prompt_embeds( + self.tokenizer, + self.text_encoder, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=self.transformer.dtype) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + if not is_ng_none(negative_prompt): + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_t5_prompt_embeds( + self.tokenizer, + self.text_encoder, + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ).to(dtype=self.transformer.dtype) + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, negative_prompt_embeds, text_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
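+    # A short worked illustration (for clarity only; it mirrors the combination applied in
+    # `__call__` below and adds no new behavior): the unconditional and text-conditioned
+    # predictions returned by the transformer are merged as
+    #
+    #     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    #
+    # so `guidance_scale = 1` simply returns the text-conditioned prediction, while e.g.
+    # `guidance_scale = 5` moves the result four gap-lengths beyond it, away from the
+    # unconditional branch.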
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + clip_value: Union[None, float] = None, + normalize: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
+        (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
+        latents, latent_image_ids = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        if (
+            isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
+            and self.scheduler.config["use_dynamic_shifting"]
+        ):
+            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            image_seq_len = latents.shape[1]  # Shift by height - Why just height?
+            logger.debug(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
+
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler,
+                num_inference_steps,
+                device,
+                timesteps,
+                sigmas,
+                mu=mu,
+            )
+        else:
+            # 4. Prepare timesteps
+            # Sample from training sigmas
+            if isinstance(self.scheduler, DDIMScheduler) or isinstance(
+                self.scheduler, EulerAncestralDiscreteScheduler
+            ):
+                timesteps, num_inference_steps = retrieve_timesteps(
+                    self.scheduler, num_inference_steps, device, None, None
+                )
+            else:
+                sigmas = get_original_sigmas(
+                    num_train_timesteps=self.scheduler.config.num_train_timesteps,
+                    num_inference_steps=num_inference_steps,
+                )
+                timesteps, num_inference_steps = retrieve_timesteps(
+                    self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+                )
+
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # Support different diffusers versions
+
+        if len(latent_image_ids.shape) == 3:
+            latent_image_ids = latent_image_ids[0]
+        if len(text_ids.shape) == 3:
+            text_ids = text_ids[0]
+
+
+
+        # 6.
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + cfg_noise_pred_text = noise_pred_text.std() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if normalize: + noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred + + if clip_value: + assert clip_value > 0 + noise_pred = noise_pred.clip(-clip_value, clip_value) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in 
{self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def to(self, *args, **kwargs): + DiffusionPipeline.to(self, *args, **kwargs) + # T5 is senstive to precision so we use the precision used for precompute and cast as needed + self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) + for block in self.text_encoder.encoder.block: + block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + + if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32: + self.vae.to(dtype=torch.float32) + + return self + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
+        batch_size, num_patches, channels = latents.shape
+
+        height = height // vae_scale_factor
+        width = width // vae_scale_factor
+
+        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+        return latents
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+        latent_image_ids = latent_image_ids.reshape(
+            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
diff --git a/src/diffusers/pipelines/bria/pipeline_output.py b/src/diffusers/pipelines/bria/pipeline_output.py
new file mode 100644
index 000000000000..2cda68de292f
--- /dev/null
+++ b/src/diffusers/pipelines/bria/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class BriaPipelineOutput(BaseOutput):
+    """
+    Output class for the Bria pipeline.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/tests/pipelines/bria/__init__.py b/tests/pipelines/bria/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py
new file mode 100644
index 000000000000..cfeafe3ea3f5
--- /dev/null
+++ b/tests/pipelines/bria/test_pipeline_bria.py
@@ -0,0 +1,372 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest +import tempfile + +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from transformers import T5EncoderModel, T5TokenizerFast + +from diffusers import ( + AutoencoderKL, + BriaTransformer2DModel, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.pipelines.bria import BriaPipeline +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + nightly, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + + +# from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist +from tests.pipelines.test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist + +enable_full_determinism() + + + +class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + + # there is no xformers processor for Flux + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D"], + latent_channels=4, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 16, + "width": 16, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_bria_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + assert max_diff > 1e-6 + + + + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 
2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_inference(self): + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 32, 32, 3)) + expected_slice = np.array( + [0.5361328, 0.5253906, 0.5234375, 0.5292969, 0.5214844, 0.5185547, 0.5283203, 0.5205078, 0.519043] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) + self.assertLess(max_diff, 1e-4) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # check that all modules are float32 except for text_encoder + for name, module in pipe.components.items(): + if not hasattr(module, "dtype"): + continue + + if name == "text_encoder": + self.assertEqual(module.dtype, torch.float16) + else: + self.assertEqual(module.dtype, torch.float32) + + pipe.to(torch.bfloat16) + + # check that all modules are bfloat16 except for text_encoder (float16) and vae (float32) + for name, module in pipe.components.items(): + if not hasattr(module, "dtype"): + continue + + if name == "text_encoder": + self.assertEqual(module.dtype, torch.float16) + elif name == "vae": + self.assertEqual(module.dtype, torch.float32) + else: + self.assertEqual(module.dtype, torch.bfloat16) + + + def test_bria_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(16, 16), (32, 32), (64, 64)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + torch_dtype_dict = {"transformer": torch.bfloat16, "default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict) + + self.assertEqual(loaded_pipe.transformer.dtype, torch.bfloat16) + self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16) + self.assertEqual(loaded_pipe.vae.dtype, torch.float32) + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + torch_dtype_dict = {"default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict) + + self.assertEqual(loaded_pipe.transformer.dtype, torch.float16) + self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16) + self.assertEqual(loaded_pipe.vae.dtype, torch.float32) + + + + +@slow +@require_torch_gpu +class BriaPipelineSlowTests(unittest.TestCase): + pipeline_class = BriaPipeline + repo_id = "briaai/BRIA-3.2" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + 
def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + generator = torch.Generator(device="cpu").manual_seed(seed) + prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt" + ) + ).to(device) + return { + "prompt_embeds": prompt_embeds, + "num_inference_steps": 2, + "guidance_scale": 0.0, + "output_type": "np", + "generator": generator, + } + + def test_bria_inference_bf16(self): + pipe = self.pipeline_class.from_pretrained( + self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None + ) + pipe.to(torch_device) + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10, 0].flatten() + + expected_slice = np.array( + [ + 0.3242, + 0.3203, + 0.3164, + 0.3164, + 0.3125, + 0.3125, + 0.3281, + 0.3242, + 0.3203, + 0.3301, + 0.3262, + 0.3242, + 0.3281, + 0.3242, + 0.3203, + 0.3262, + 0.3262, + 0.3164, + 0.3262, + 0.3281, + 0.3184, + 0.3281, + 0.3281, + 0.3203, + 0.3281, + 0.3281, + 0.3164, + 0.332, + 0.332, + 0.3203, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) + self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") + + +@nightly +@require_torch_gpu +class BriaPipelineNightlyTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_bria_inference(self): + pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16) + pipe.to(torch_device) + + prompt = "a close-up of a smiling cat, high quality, realistic" + image = pipe(prompt=prompt, num_inference_steps=5, output_type="np").images[0] + + image_slice = image[0, :10, :10, 0].flatten() + expected_slice = np.array( + [ + 0.668, + 0.668, + 0.6641, + 0.6602, + 0.6602, + 0.6562, + 0.6523, + 0.6484, + 0.6523, + 0.6562, + 0.668, + 0.668, + 0.6641, + 0.6641, + 0.6602, + 0.6562, + 0.6523, + 0.6484, + 0.6523, + 0.6562, + 0.668, + 0.668, + 0.668, + 0.6641, + 0.6602, + 0.6562, + 0.6523, + 0.6484, + 0.6523, + 0.6562, + ] + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) + self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") + From 7808ee0c198d962a175fea022f9d12c56b2e6d51 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 29 Jul 2025 08:09:31 +0000 Subject: [PATCH 02/18] with working tests --- .../models/transformers/transformer_bria.py | 65 +++++-- src/diffusers/pipelines/bria/bria_utils.py | 57 +++--- src/diffusers/pipelines/bria/pipeline_bria.py | 18 +- .../test_models_transformer_bria.py | 184 ++++++++++++++++++ tests/pipelines/bria/test_pipeline_bria.py | 115 ++++++----- 5 files changed, 345 insertions(+), 94 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_bria.py diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 92b413aa81ac..12647341cccb 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -13,6 +13,9 @@ from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from diffusers.pipelines.bria.bria_utils import FluxPosEmbed as EmbedND from diffusers.utils import 
USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers import __version__ as diffusers_version +from packaging import version + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -240,10 +243,10 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - + # def _set_gradient_checkpointing(self, module, enable=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = enable + def forward( self, hidden_states: torch.Tensor, @@ -321,11 +324,15 @@ def forward( temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if len(txt_ids.shape) == 3: + txt_ids = txt_ids[0] - if len(txt_ids.shape) == 2: - ids = torch.cat((txt_ids, img_ids), dim=0) - else: - ids = torch.cat((txt_ids, img_ids), dim=1) + if len(img_ids.shape) == 3: + img_ids = img_ids[0] + + + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): @@ -364,7 +371,8 @@ def custom_forward(*inputs): interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - # hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -379,21 +387,38 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) + if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) else: - encoder_hidden_states, hidden_states = block( + if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): + hidden_states = block( hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - ) + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) # controlnet residual if controlnet_single_block_samples is not None: @@ -403,6 +428,8 @@ def custom_forward(*inputs): hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ controlnet_single_block_samples[index_block // interval_control] ) + if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/bria/bria_utils.py b/src/diffusers/pipelines/bria/bria_utils.py index fbec780333fc..236d8eab9b26 100644 --- a/src/diffusers/pipelines/bria/bria_utils.py +++ b/src/diffusers/pipelines/bria/bria_utils.py @@ -114,42 +114,43 @@ def get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - # padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" + prompt_embeds_list = [] + for p in prompt: + text_inputs = tokenizer( + p, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - prompt_embeds = text_encoder(text_input_ids.to(device))[0] + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) - # Concat zeros to max_sequence - b, seq_len, dim = prompt_embeds.shape - if seq_len < max_sequence_length: - padding = torch.zeros( - (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device - ) - prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(device=device) + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds_list.append(prompt_embeds) - _, seq_len, _ = prompt_embeds.shape + prompt_embeds = torch.concat(prompt_embeds_list, dim=0) + prompt_embeds = prompt_embeds.to(device=device) # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) return prompt_embeds diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 1b9542f6ddd8..6cb0d8663d85 100644 --- 
a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -3,6 +3,8 @@ import numpy as np import torch from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast, ) @@ -81,6 +83,9 @@ class BriaPipeline(FluxPipeline): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). """ + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, @@ -89,6 +94,8 @@ def __init__( vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, ): self.register_modules( vae=vae, @@ -96,6 +103,8 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) # TODO - why different than offical flux (-1) @@ -541,9 +550,12 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - + # if height % 8 != 0 or width % 8 != 0: + # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py new file mode 100644 index 000000000000..0f18053c4537 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_bria.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import BriaTransformer2DModel +from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 +from diffusers.models.embeddings import ImageProjection +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +def create_chroma_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + # "image_proj" (ImageProjection layer weights) + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=model.config["pooled_projection_dim"], + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class BriaTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = BriaTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. 
+ model_split_percents = [0.8, 0.7, 0.7] + + # Skip setting testing with default: AttnProcessor + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + } + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": None, + "axes_dims_rope": [0, 4, 4], + + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_deprecated_inputs_img_txt_ids_3d(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output_1 = model(**inputs_dict).to_tuple()[0] + + # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) + text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) + image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + + assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" + assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" + + inputs_dict["txt_ids"] = text_ids_3d + inputs_dict["img_ids"] = image_ids_3d + + with torch.no_grad(): + output_2 = model(**inputs_dict).to_tuple()[0] + + self.assertEqual(output_1.shape, output_2.shape) + self.assertTrue( + torch.allclose(output_1, output_2, atol=1e-5), + msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + ) + + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"BriaTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = BriaTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return BriaTransformerTests().prepare_init_args_and_inputs_for_common() + + +class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = BriaTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return BriaTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py index cfeafe3ea3f5..8cecbe2eb01e 100644 --- a/tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/pipelines/bria/test_pipeline_bria.py @@ -33,13 +33,14 @@ nightly, numpy_cosine_similarity_distance, require_torch_gpu, + require_accelerator, slow, torch_device, ) # from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist -from 
tests.pipelines.test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist +from tests.pipelines.test_pipelines_common import PipelineTesterMixin,to_np enable_full_determinism() @@ -72,13 +73,19 @@ def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKL( - block_out_channels=[32], + act_fn="silu", + block_out_channels=(32,), in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"], latent_channels=4, sample_size=32, + shift_factor=0, + scaling_factor=0.13025, + use_post_quant_conv=True, + use_quant_conv=True, + force_upcast=False, ) scheduler = FlowMatchEulerDiscreteScheduler() @@ -93,6 +100,8 @@ def get_dummy_components(self): "tokenizer": tokenizer, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } return components @@ -114,7 +123,8 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", } return inputs - + def test_encode_prompt_works_in_isolation(self): + pass def test_bria_different_prompts(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -142,53 +152,51 @@ def test_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) - def test_inference(self): - device = "cpu" + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator + def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice = image[0, -3:, -3:, -1] - - self.assertEqual(image.shape, (1, 32, 32, 3)) - expected_slice = np.array( - [0.5361328, 0.5253906, 0.5234375, 0.5292969, 0.5214844, 0.5185547, 0.5283203, 0.5205078, 0.519043] - ) - - max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) - self.assertLess(max_diff, 1e-4) + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() - def test_to_dtype(self): - components = self.get_dummy_components() pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - # check that all modules are float32 except for text_encoder - for name, module in pipe.components.items(): - if not hasattr(module, "dtype"): + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if name == "vae": continue + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) - if name == "text_encoder": - self.assertEqual(module.dtype, 
torch.float16) - else: - self.assertEqual(module.dtype, torch.float32) - - pipe.to(torch.bfloat16) - - # check that all modules are bfloat16 except for text_encoder (float16) and vae (float32) - for name, module in pipe.components.items(): - if not hasattr(module, "dtype"): - continue + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) + + - if name == "text_encoder": - self.assertEqual(module.dtype, torch.float16) - elif name == "vae": - self.assertEqual(module.dtype, torch.float32) - else: - self.assertEqual(module.dtype, torch.bfloat16) def test_bria_image_output_shape(self): @@ -205,6 +213,13 @@ def test_bria_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False,True,True]) def test_torch_dtype_dict(self): components = self.get_dummy_components() @@ -217,7 +232,7 @@ def test_torch_dtype_dict(self): self.assertEqual(loaded_pipe.transformer.dtype, torch.bfloat16) self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16) - self.assertEqual(loaded_pipe.vae.dtype, torch.float32) + self.assertEqual(loaded_pipe.vae.dtype, torch.float16) with tempfile.TemporaryDirectory() as tmpdirname: pipe.save_pretrained(tmpdirname) @@ -226,7 +241,7 @@ def test_torch_dtype_dict(self): self.assertEqual(loaded_pipe.transformer.dtype, torch.float16) self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16) - self.assertEqual(loaded_pipe.vae.dtype, torch.float32) + self.assertEqual(loaded_pipe.vae.dtype, torch.float16) @@ -309,6 +324,18 @@ def test_bria_inference_bf16(self): ) max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) @nightly From c58267e60023beb4d46c73112439a7c9ec91f8b0 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 29 Jul 2025 08:15:19 +0000 Subject: [PATCH 03/18] style and quality pass --- .../models/transformers/transformer_bria.py | 16 ++++----- src/diffusers/pipelines/bria/pipeline_bria.py | 4 +-- .../test_models_transformer_bria.py | 5 +-- tests/pipelines/bria/test_pipeline_bria.py | 33 +++++++------------ 4 files changed, 20 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 12647341cccb..8741bd022c65 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ 
b/src/diffusers/models/transformers/transformer_bria.py @@ -3,7 +3,9 @@ import numpy as np import torch import torch.nn as nn +from packaging import version +from diffusers import __version__ as diffusers_version from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding @@ -13,9 +15,6 @@ from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from diffusers.pipelines.bria.bria_utils import FluxPosEmbed as EmbedND from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from diffusers import __version__ as diffusers_version -from packaging import version - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -246,7 +245,7 @@ def __init__( # def _set_gradient_checkpointing(self, module, enable=False): # if hasattr(module, "gradient_checkpointing"): # module.gradient_checkpointing = enable - + def forward( self, hidden_states: torch.Tensor, @@ -324,14 +323,13 @@ def forward( temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - + if len(txt_ids.shape) == 3: txt_ids = txt_ids[0] if len(img_ids.shape) == 3: img_ids = img_ids[0] - ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) @@ -408,9 +406,9 @@ def custom_forward(*inputs): else: if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, ) else: encoder_hidden_states, hidden_states = block( diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 6cb0d8663d85..05c06e757e18 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -9,7 +9,6 @@ T5TokenizerFast, ) -import diffusers from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxLoraLoaderMixin @@ -83,6 +82,7 @@ class BriaPipeline(FluxPipeline): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). """ + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -455,8 +455,6 @@ def __call__( latent_image_ids = latent_image_ids[0] if len(text_ids.shape) == 3: text_ids = text_ids[0] - - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py index 0f18053c4537..d95e2950f91d 100644 --- a/tests/models/transformers/test_models_transformer_bria.py +++ b/tests/models/transformers/test_models_transformer_bria.py @@ -120,7 +120,6 @@ def output_shape(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -130,7 +129,6 @@ def prepare_init_args_and_inputs_for_common(self): "joint_attention_dim": 32, "pooled_projection_dim": None, "axes_dims_rope": [0, 4, 4], - } inputs_dict = self.dummy_input @@ -163,8 +161,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self): torch.allclose(output_1, output_2, atol=1e-5), msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", ) - - + def test_gradient_checkpointing_is_applied(self): expected_set = {"BriaTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py index 8cecbe2eb01e..c85ba045371a 100644 --- a/tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/pipelines/bria/test_pipeline_bria.py @@ -13,8 +13,8 @@ # limitations under the License. import gc -import unittest import tempfile +import unittest import numpy as np import torch @@ -32,18 +32,17 @@ enable_full_determinism, nightly, numpy_cosine_similarity_distance, - require_torch_gpu, require_accelerator, + require_torch_gpu, slow, torch_device, ) - # from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist -from tests.pipelines.test_pipelines_common import PipelineTesterMixin,to_np +from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np -enable_full_determinism() +enable_full_determinism() class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -56,7 +55,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True test_group_offloading = True - + def get_dummy_components(self): torch.manual_seed(0) transformer = BriaTransformer2DModel( @@ -123,8 +122,10 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", } return inputs + def test_encode_prompt_works_in_isolation(self): pass + def test_bria_different_prompts(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -135,9 +136,6 @@ def test_bria_different_prompts(self): max_diff = np.abs(output_same_prompt - output_different_prompts).max() assert max_diff > 1e-6 - - - def test_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -194,11 +192,7 @@ def test_save_load_float16(self, expected_max_diff=1e-2): self.assertLess( max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." 
) - - - - def test_bria_image_output_shape(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -217,9 +211,9 @@ def test_to_dtype(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - + model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False,True,True]) + self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False, True, True]) def test_torch_dtype_dict(self): components = self.get_dummy_components() @@ -243,8 +237,6 @@ def test_torch_dtype_dict(self): self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16) self.assertEqual(loaded_pipe.vae.dtype, torch.float16) - - @slow @require_torch_gpu @@ -265,9 +257,7 @@ def tearDown(self): def get_inputs(self, device, seed=0): generator = torch.Generator(device="cpu").manual_seed(seed) prompt_embeds = torch.load( - hf_hub_download( - repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt" - ) + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt") ).to(device) return { "prompt_embeds": prompt_embeds, @@ -324,7 +314,7 @@ def test_bria_inference_bf16(self): ) max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") - + def test_to_dtype(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -396,4 +386,3 @@ def test_bria_inference(self): max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") - From 7c1cf7ef3cf4637942c652723e5122a6a676c16f Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 29 Jul 2025 10:48:30 +0000 Subject: [PATCH 04/18] adding docs --- docs/source/en/_toctree.yml | 4 ++ docs/source/en/api/models/bria_transformer.md | 19 ++++++++ docs/source/en/api/pipelines/bria_3_2.md | 48 +++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 docs/source/en/api/models/bria_transformer.md create mode 100644 docs/source/en/api/pipelines/bria_3_2.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b095b2cc1a73..3e0b7c5c9641 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -326,6 +326,8 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel + - local: api/models/bria_transformer2d + title: BriaTransformer2DModel - local: api/models/chroma_transformer title: ChromaTransformer2DModel - local: api/models/cogvideox_transformer3d @@ -450,6 +452,8 @@ title: AutoPipeline - local: api/pipelines/blip_diffusion title: BLIP-Diffusion + - local: api/pipelines/bria_3_2 + title: Bria 3.2 - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogvideox diff --git a/docs/source/en/api/models/bria_transformer.md b/docs/source/en/api/models/bria_transformer.md new file mode 100644 index 000000000000..9df7eeb6ffcd --- /dev/null +++ b/docs/source/en/api/models/bria_transformer.md @@ -0,0 +1,19 @@ + + +# BriaTransformer2DModel + +A modified flux Transformer model from [Bria](https://huggingface.co/briaai/BRIA-3.2) + +## BriaTransformer2DModel + +[[autodoc]] BriaTransformer2DModel diff 
--git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md
new file mode 100644
index 000000000000..55689c7c9623
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_3_2.md
@@ -0,0 +1,48 @@
+
+
+# Bria 3.2
+
+Bria 3.2 is a commercial-ready text-to-image model. With just 4 billion parameters, it provides exceptional aesthetics and text rendering, and has been evaluated to be on par with leading open-source models while outperforming other licensed models.
+In addition to being built entirely on licensed data, Bria 3.2 provides several advantages for enterprise and commercial use:
+
+- Efficient Compute: the model is 3x smaller than equivalent models on the market (4B parameters vs. 12B parameters for other open-source models).
+- Architecture Consistency: same architecture as Bria 3.1, making it ideal for users looking to upgrade without disruption.
+- Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100.
+
+Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).
+
+
+## Inference

+
+The Diffusers-format weights for Bria 3.2 are available under the `bria_3_2_diffusers` revision of the [briaai/BRIA-3.2](https://huggingface.co/briaai/BRIA-3.2) repository, as shown in the example below.
+
+```python
+import torch
+from diffusers import BriaPipeline
+
+pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", revision="bria_3_2_diffusers", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
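+
+# Optional controls (a minimal sketch; the values below are illustrative, not tuned defaults):
+# the pipeline call also accepts `negative_prompt` and `guidance_scale` for classifier-free guidance, e.g.
+# image = pipe(prompt, negative_prompt="blurry, low quality", guidance_scale=5.0).images[0]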
+ +image = pipe(prompt).images[0] +image.save("bria.png") +``` + + +## BriaPipeline + +[[autodoc]] BriaPipeline + - all + - __call__ + From 92671ab8eabd46377eaed6f447da0f60ea905362 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 29 Jul 2025 10:52:01 +0000 Subject: [PATCH 05/18] add to overview --- docs/source/en/api/pipelines/overview.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 4e7a4e5e8da2..f34262d37ce0 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -37,6 +37,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [AudioLDM2](audioldm2) | text2audio | | [AuraFlow](auraflow) | text2image | | [BLIP Diffusion](blip_diffusion) | text2image | +| [Bria 3.2](bria_3_2) | text2image | | [CogVideoX](cogvideox) | text2video | | [Consistency Models](consistency_models) | unconditional image generation | | [ControlNet](controlnet) | text2image, image2image, inpainting | From be296312c337c3fad355d732e7cf585aba629934 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Wed, 30 Jul 2025 11:31:27 +0000 Subject: [PATCH 06/18] fixes from "make fix-copies" --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 901aec4b2205..09a431716b0e 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -498,6 +498,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BriaTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CacheMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7538635c808e..ddd5c28448b1 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -332,6 +332,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class BriaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ChromaImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 88c8e425f9170ae6f8cd30cf8e51e9b62e2d56e1 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Wed, 30 Jul 2025 14:28:23 +0000 Subject: [PATCH 07/18] Refactor transformer_bria.py and pipeline_bria.py: Introduce new EmbedND class for rotary position embedding, and enhance Timestep and TimestepProjEmbeddings classes. Add utility functions for handling negative prompts and generating original sigmas in pipeline_bria.py. 
--- .../models/transformers/transformer_bria.py | 107 +++++++++++------- src/diffusers/pipelines/bria/pipeline_bria.py | 73 +++++++++++- 2 files changed, 140 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 8741bd022c65..8e2a023b184a 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -13,51 +13,12 @@ from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock -from diffusers.pipelines.bria.bria_utils import FluxPosEmbed as EmbedND from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class Timesteps(nn.Module): - def __init__( - self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 - ): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - self.scale = scale - self.time_theta = time_theta - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - max_period=self.time_theta, - ) - return t_emb - - -class TimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, time_theta): - super().__init__() - - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta - ) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - def forward(self, timestep, dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) - return timesteps_emb - - def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], @@ -124,6 +85,74 @@ def get_1d_rotary_pos_embed( return freqs_cis +class EmbedND(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.time_theta = time_theta + + def forward(self, 
timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.time_theta, + ) + return t_emb + + +class TimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, time_theta): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta + ) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D) + return timesteps_emb + + class FluxPosEmbed(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 05c06e757e18..26ae59a6b7d9 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -13,7 +13,6 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxLoraLoaderMixin from diffusers.models.transformers.transformer_bria import BriaTransformer2DModel -from diffusers.pipelines.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -57,6 +56,78 @@ T5_PRECISION = torch.float16 + +def is_ng_none(negative_prompt): + return ( + negative_prompt is None + or negative_prompt == "" + or (isinstance(negative_prompt, list) and negative_prompt[0] is None) + or (type(negative_prompt) == list and negative_prompt[0] == "") + ) + + +def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + sigmas = timesteps / num_train_timesteps + + inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] + new_sigmas = sigmas[inds] + return new_sigmas + + +def get_t5_prompt_embeds( + tokenizer: T5TokenizerFast, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, +): + device = device or text_encoder.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + prompt_embeds_list = [] + for p in prompt: + text_inputs = tokenizer( + p, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = 
text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=0) + prompt_embeds = prompt_embeds.to(device=device) + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) + + return prompt_embeds + + """ Based on FluxPipeline with several changes: - no pooled embeddings From 6cefe44e4bb1ce01d361b60e5b489829bf0526dd Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Thu, 31 Jul 2025 08:38:40 +0000 Subject: [PATCH 08/18] remove redundent and duplicates tests and fix bf16 slow test --- src/diffusers/pipelines/bria/bria_utils.py | 465 ------------------ src/diffusers/pipelines/bria/pipeline_bria.py | 14 +- tests/pipelines/bria/test_pipeline_bria.py | 114 +---- 3 files changed, 17 insertions(+), 576 deletions(-) delete mode 100644 src/diffusers/pipelines/bria/bria_utils.py diff --git a/src/diffusers/pipelines/bria/bria_utils.py b/src/diffusers/pipelines/bria/bria_utils.py deleted file mode 100644 index 236d8eab9b26..000000000000 --- a/src/diffusers/pipelines/bria/bria_utils.py +++ /dev/null @@ -1,465 +0,0 @@ -import math -import os -from typing import List, Optional, Union - -import numpy as np -import torch -import torch.distributed as dist -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR -from transformers import ( - AutoTokenizer, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from diffusers.optimization import get_scheduler -from diffusers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def get_text(caption): - existing_text_list = set() - - if caption[0] == '"' and caption[-1] == '"': - caption = caption[1:-2] - - if caption[0] == "'" and caption[-1] == "'": - caption = caption[1:-2] - - text_list = [] - current_text = "" - text_present = False - for c in caption: - if c == '"' and not text_present: - text_present = True - continue - - if c == '"' and text_present: - if current_text not in existing_text_list: - text_list += [current_text] - existing_text_list.add(current_text) - - text_present = False - current_text = "" - continue - - if text_present: - current_text += c - - return text_list - - -def get_by_t5_prompt_embeds( - tokenizer: AutoTokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - max_sequence_length: int = 128, - device: Optional[torch.device] = None, -): - device = device or text_encoder.device - - if isinstance(prompt, list): - assert len(prompt) == 1 - prompt = prompt[0] - - assert type(prompt) == str - - captions_list = get_text(prompt) - embeddings_list = [] - for inner_prompt in captions_list: - text_inputs = tokenizer( - [inner_prompt], - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - embeddings_list += [prompt_embeds[0]] - - # No Text Found - if 
len(embeddings_list) == 0: - return None - - prompt_embeds = torch.concatenate(embeddings_list, axis=0) - - # Concat zeros to max_sequence - seq_len, dim = prompt_embeds.shape - if seq_len < max_sequence_length: - padding = torch.zeros( - (max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device - ) - prompt_embeds = torch.concat([prompt_embeds, padding], dim=0) - - prompt_embeds = prompt_embeds.to(device=device) - return prompt_embeds - - -def get_t5_prompt_embeds( - tokenizer: T5TokenizerFast, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 128, - device: Optional[torch.device] = None, -): - device = device or text_encoder.device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - prompt_embeds_list = [] - for p in prompt: - text_inputs = tokenizer( - p, - # padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - - # Concat zeros to max_sequence - b, seq_len, dim = prompt_embeds.shape - if seq_len < max_sequence_length: - padding = torch.zeros( - (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device - ) - prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=0) - prompt_embeds = prompt_embeds.to(device=device) - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) - - return prompt_embeds - - -# in order the get the same sigmas as in training and sample from them -def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): - timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() - sigmas = timesteps / num_train_timesteps - - inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] - new_sigmas = sigmas[inds] - return new_sigmas - - -def is_ng_none(negative_prompt): - return ( - negative_prompt is None - or negative_prompt == "" - or (isinstance(negative_prompt, list) and negative_prompt[0] is None) - or (type(negative_prompt) == list and negative_prompt[0] == "") - ) - - -class CudaTimerContext: - def __init__(self, times_arr): - self.times_arr = times_arr - - def __enter__(self): - self.before_event = torch.cuda.Event(enable_timing=True) - self.after_event = torch.cuda.Event(enable_timing=True) - self.before_event.record() - - def __exit__(self, type, value, traceback): - self.after_event.record() - torch.cuda.synchronize() - elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000 - self.times_arr.append(elapsed_time) - - -def 
get_env_prefix(): - env = os.environ.get("CLOUD_PROVIDER", "AWS").upper() - if env == "AWS": - return "SM_CHANNEL" - elif env == "AZURE": - return "AZUREML_DATAREFERENCE" - - raise Exception(f"Env {env} not supported") - - -def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None -): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - - -def initialize_distributed(): - # Initialize the process group for distributed training - dist.init_process_group("nccl") - - # Get the current process's rank (ID) and the total number of processes (world size) - rank = dist.get_rank() - world_size = dist.get_world_size() - - print(f"Initialized distributed training: Rank {rank}/{world_size}") - - -def get_clip_prompt_embeds( - text_encoder: CLIPTextModel, - text_encoder_2: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - tokenizer_2: CLIPTokenizer, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 77, - device: Optional[torch.device] = None, -): - device = device or text_encoder.device - assert max_sequence_length == tokenizer.model_max_length - prompt = [prompt] if isinstance(prompt, str) else prompt - - # Define tokenizers and text encoders - tokenizers = [tokenizer, tokenizer_2] - text_encoders = [text_encoder, text_encoder_2] - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - return prompt_embeds, pooled_prompt_embeds - - -def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end - index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 - data type. - - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - linear_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the context extrapolation. Defaults to 1.0. - ntk_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. - repeat_interleave_real (`bool`, *optional*, defaults to `True`): - If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. - Otherwise, they are concateanted with themselves. - freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): - the dtype of the frequency tensor. - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] - """ - assert dim % 2 == 0 - - if isinstance(pos, int): - pos = torch.arange(pos) - if isinstance(pos, np.ndarray): - pos = torch.from_numpy(pos) # type: ignore # [S] - - theta = theta * ntk_factor - freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor - ) # [D/2] - freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] - if use_real and repeat_interleave_real: - # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] - return freqs_cos, freqs_sin - elif use_real: - # stable audio, allegro - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] - return freqs_cos, freqs_sin - else: - # lumina - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis - - -class FluxPosEmbed(torch.nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: List[int]): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - cos_out = [] - sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 - for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], - pos[:, i], - theta=self.theta, - repeat_interleave_real=True, - use_real=True, - freqs_dtype=freqs_dtype, - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin - - -# Not really cosine but with decay -def get_cosine_schedule_with_warmup_and_decay( - optimizer: Optimizer, - num_warmup_steps: int, - num_training_steps: int, - num_cycles: float = 0.5, - last_epoch: int = -1, - constant_steps=-1, - eps=1e-5, -) -> LambdaLR: - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_periods (`float`, *optional*, defaults to 0.5): - The number of periods of the cosine function in a schedule (the default is to just decrease from the max - value to 0 following a half-cosine). - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - constant_steps (`int`): - The total number of constant lr steps following a warmup - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
- """ - if constant_steps <= 0: - constant_steps = num_training_steps - num_warmup_steps - - def lr_lambda(current_step): - # Accelerate sends current_step*num_processes - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - elif current_step < num_warmup_steps + constant_steps: - return 1 - - # print(f'Inside LR: num_training_steps:{num_training_steps}, current_step:{current_step}, num_warmup_steps: {num_warmup_steps}, constant_steps: {constant_steps}') - return max( - eps, - float(num_training_steps - current_step) - / float(max(1, num_training_steps - num_warmup_steps - constant_steps)), - ) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_lr_scheduler(name, optimizer, num_warmup_steps, num_training_steps, constant_steps): - if name != "constant_with_warmup_cosine_decay": - return get_scheduler( - name=name, optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps - ) - - # Usign custom warmup+cnstant+decay scheduler - return get_cosine_schedule_with_warmup_and_decay( - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - constant_steps=constant_steps, - ) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 26ae59a6b7d9..84571cc1f295 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -186,9 +186,10 @@ def __init__( self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k # T5 is senstive to precision so we use the precision used for precompute and cast as needed - self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) - for block in self.text_encoder.encoder.block: - block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + if self.text_encoder is not None: + self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) + for block in self.text_encoder.encoder.block: + block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) if self.vae.config.shift_factor is None: self.vae.config.shift_factor = 0 @@ -664,9 +665,10 @@ def check_inputs( def to(self, *args, **kwargs): DiffusionPipeline.to(self, *args, **kwargs) # T5 is senstive to precision so we use the precision used for precompute and cast as needed - self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) - for block in self.text_encoder.encoder.block: - block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + if self.text_encoder is not None: + self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) + for block in self.text_encoder.encoder.block: + block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32: self.vae.to(dtype=torch.float32) diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py index c85ba045371a..26c1771ee94b 100644 --- a/tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/pipelines/bria/test_pipeline_bria.py @@ -256,17 +256,20 @@ def tearDown(self): def get_inputs(self, device, seed=0): generator = torch.Generator(device="cpu").manual_seed(seed) + prompt_embeds = torch.load( - hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt") - ).to(device) + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ).to(torch_device) + return { "prompt_embeds": prompt_embeds, "num_inference_steps": 2, 
"guidance_scale": 0.0, + "max_sequence_length": 256, "output_type": "np", "generator": generator, } - + def test_bria_inference_bf16(self): pipe = self.pipeline_class.from_pretrained( self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None @@ -276,113 +279,14 @@ def test_bria_inference_bf16(self): inputs = self.get_inputs(torch_device) image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10, 0].flatten() + image_slice = image[0, :10, :10].flatten() expected_slice = np.array( - [ - 0.3242, - 0.3203, - 0.3164, - 0.3164, - 0.3125, - 0.3125, - 0.3281, - 0.3242, - 0.3203, - 0.3301, - 0.3262, - 0.3242, - 0.3281, - 0.3242, - 0.3203, - 0.3262, - 0.3262, - 0.3164, - 0.3262, - 0.3281, - 0.3184, - 0.3281, - 0.3281, - 0.3203, - 0.3281, - 0.3281, - 0.3164, - 0.332, - 0.332, - 0.3203, - ] + [0.59729785, 0.6153719, 0.595112, 0.5884763, 0.59366125, 0.5795311, 0.58325, 0.58449626, 0.57737637, 0.58432233, 0.5867875, 0.57824117, 0.5819089, 0.5830988, 0.57730293, 0.57647324, 0.5769151, 0.57312685, 0.57926565, 0.5823928, 0.57783926, 0.57162863, 0.575649, 0.5745547, 0.5740556, 0.5799735, 0.57799566, 0.5715559, 0.5771242, 0.5773058], + dtype=np.float32, ) max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") - def test_to_dtype(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - - pipe.to(dtype=torch.float16) - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) - - -@nightly -@require_torch_gpu -class BriaPipelineNightlyTests(unittest.TestCase): - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - - def test_bria_inference(self): - pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16) - pipe.to(torch_device) - prompt = "a close-up of a smiling cat, high quality, realistic" - image = pipe(prompt=prompt, num_inference_steps=5, output_type="np").images[0] - image_slice = image[0, :10, :10, 0].flatten() - expected_slice = np.array( - [ - 0.668, - 0.668, - 0.6641, - 0.6602, - 0.6602, - 0.6562, - 0.6523, - 0.6484, - 0.6523, - 0.6562, - 0.668, - 0.668, - 0.6641, - 0.6641, - 0.6602, - 0.6562, - 0.6523, - 0.6484, - 0.6523, - 0.6562, - 0.668, - 0.668, - 0.668, - 0.6641, - 0.6602, - 0.6562, - 0.6523, - 0.6484, - 0.6523, - 0.6562, - ] - ) - - max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) - self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") From f27d122fe2f5a00c1311e684a1923b1010204191 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Thu, 31 Jul 2025 10:13:43 +0000 Subject: [PATCH 09/18] style fixes --- tests/pipelines/bria/test_pipeline_bria.py | 40 ++++++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py index 26c1771ee94b..2ded184e2109 100644 --- a/tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/pipelines/bria/test_pipeline_bria.py @@ -28,9 +28,7 
@@ ) from diffusers.pipelines.bria import BriaPipeline from diffusers.utils.testing_utils import ( - backend_empty_cache, enable_full_determinism, - nightly, numpy_cosine_similarity_distance, require_accelerator, require_torch_gpu, @@ -269,7 +267,7 @@ def get_inputs(self, device, seed=0): "output_type": "np", "generator": generator, } - + def test_bria_inference_bf16(self): pipe = self.pipeline_class.from_pretrained( self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None @@ -282,11 +280,39 @@ def test_bria_inference_bf16(self): image_slice = image[0, :10, :10].flatten() expected_slice = np.array( - [0.59729785, 0.6153719, 0.595112, 0.5884763, 0.59366125, 0.5795311, 0.58325, 0.58449626, 0.57737637, 0.58432233, 0.5867875, 0.57824117, 0.5819089, 0.5830988, 0.57730293, 0.57647324, 0.5769151, 0.57312685, 0.57926565, 0.5823928, 0.57783926, 0.57162863, 0.575649, 0.5745547, 0.5740556, 0.5799735, 0.57799566, 0.5715559, 0.5771242, 0.5773058], + [ + 0.59729785, + 0.6153719, + 0.595112, + 0.5884763, + 0.59366125, + 0.5795311, + 0.58325, + 0.58449626, + 0.57737637, + 0.58432233, + 0.5867875, + 0.57824117, + 0.5819089, + 0.5830988, + 0.57730293, + 0.57647324, + 0.5769151, + 0.57312685, + 0.57926565, + 0.5823928, + 0.57783926, + 0.57162863, + 0.575649, + 0.5745547, + 0.5740556, + 0.5799735, + 0.57799566, + 0.5715559, + 0.5771242, + 0.5773058, + ], dtype=np.float32, ) max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}") - - - From a423221b8b2da023e25d2e20ec4848a00de28bb5 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 5 Aug 2025 14:20:05 +0000 Subject: [PATCH 10/18] small doc update --- docs/source/en/api/pipelines/bria_3_2.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md index 55689c7c9623..5f7731072c61 100644 --- a/docs/source/en/api/pipelines/bria_3_2.md +++ b/docs/source/en/api/pipelines/bria_3_2.md @@ -20,11 +20,13 @@ In addition to being built entirely on licensed data, 3.2 provides several advan - Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100. Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2). +Github repo for Bria 3.2 can be found [here](https://github.com/briaai/BRIA-3.2). + +If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai). ## Inference -The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma). ```python import torch From 5bfd7333f507cb9752c683347b62ffd92bd315e9 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Thu, 7 Aug 2025 11:09:55 +0000 Subject: [PATCH 11/18] Enhance Bria 3.2 documentation and implementation - Updated the GitHub repository link for Bria 3.2. - Added usage instructions for the gated model access. - Introduced the BriaTransformerBlock and BriaAttention classes to the model architecture. - Refactored existing classes to integrate Bria-specific components, including BriaEmbedND and BriaPipeline. - Updated the pipeline output class to reflect Bria-specific functionality. - Adjusted test cases to align with the new Bria model structure. 
--- docs/source/en/api/pipelines/bria_3_2.md | 12 +- src/diffusers/hooks/_helpers.py | 8 + .../models/transformers/transformer_bria.py | 435 ++++++++++++++---- src/diffusers/pipelines/bria/pipeline_bria.py | 206 ++++----- .../pipelines/bria/pipeline_output.py | 2 +- .../test_models_transformer_bria.py | 2 +- 6 files changed, 471 insertions(+), 194 deletions(-) diff --git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md index 5f7731072c61..19b53177549c 100644 --- a/docs/source/en/api/pipelines/bria_3_2.md +++ b/docs/source/en/api/pipelines/bria_3_2.md @@ -20,11 +20,21 @@ In addition to being built entirely on licensed data, 3.2 provides several advan - Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100. Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2). -Github repo for Bria 3.2 can be found [here](https://github.com/briaai/BRIA-3.2). +Github repo for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2). If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai). +## Usage + +_As the model is gated, before using it with diffusers you first need to go to the [Bria 3.2 Hugging Face page](https://huggingface.co/briaai/BRIA-3.2), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + ## Inference diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index f328078ce472..4ab0c1d5278b 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -143,6 +143,7 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): from ..models.attention import BasicTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock + from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from ..models.transformers.transformer_hunyuan_video import ( @@ -164,6 +165,13 @@ def _register_transformer_blocks_metadata(): return_encoder_hidden_states_index=None, ), ) + TransformerBlockRegistry.register( + model_class=BriaTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) # CogVideoX TransformerBlockRegistry.register( diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 8e2a023b184a..6cfc17e2e247 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -1,24 +1,57 @@ -from typing import Any, Dict, List, Optional, Union +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -from packaging import version +import torch.nn.functional as F -from diffusers import __version__ as diffusers_version -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin -from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from 
diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNormContinuous -from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock -from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def _get_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], @@ -85,7 +118,232 @@ def get_1d_rotary_pos_embed( return freqs_cis -class EmbedND(torch.nn.Module): +class BriaAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "BriaAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class BriaAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = BriaAttnProcessor + _available_processors = [ + BriaAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + 
self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class BriaTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = BriaAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=BriaAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. 
+ attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class BriaEmbedND(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -115,7 +373,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -class Timesteps(nn.Module): +class BriaTimesteps(nn.Module): def __init__( self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000 ): @@ -138,11 +396,11 @@ def forward(self, timesteps): return t_emb -class TimestepProjEmbeddings(nn.Module): +class BriaTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, time_theta): super().__init__() - self.time_proj = Timesteps( + self.time_proj = BriaTimesteps( num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta ) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -153,7 +411,7 @@ def forward(self, timestep, dtype): return timesteps_emb -class FluxPosEmbed(torch.nn.Module): +class BriaPosEmbed(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -183,18 +441,68 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -""" -Based on FluxPipeline with several changes: -- no pooled embeddings -- We use zero padding for prompts -- No guidance embedding since this is not a distilled version -""" +@maybe_allow_in_graph +class BriaSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = BriaAttnProcessor() + + self.attn = BriaAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + 
encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states -class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): """ - The Transformer model introduced in Flux. - + The Transformer model introduced in Flux. Based on FluxPipeline with several changes: + - no pooled embeddings + - We use zero padding for prompts + - No guidance embedding since this is not a distilled version Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Parameters: @@ -231,22 +539,18 @@ def __init__( self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope) - - self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) - - # if pooled_projection_dim: - # self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu") + self.pos_embed = BriaEmbedND(theta=rope_theta, axes_dim=axes_dims_rope) + self.time_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta) if guidance_embeds: - self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim) + self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ - FluxTransformerBlock( + BriaTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, @@ -257,7 +561,7 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ - FluxSingleTransformerBlock( + BriaSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, @@ -271,10 +575,6 @@ def __init__( self.gradient_checkpointing = False - # def _set_gradient_checkpointing(self, module, enable=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = enable - def forward( self, hidden_states: torch.Tensor, @@ -337,17 +637,8 @@ def forward( else: guidance = None - # temb = ( - # self.time_text_embed(timestep, pooled_projections) - # 
if guidance is None - # else self.time_text_embed(timestep, guidance, pooled_projections) - # ) - temb = self.time_embed(timestep, dtype=hidden_states.dtype) - # if pooled_projections: - # temb+=self.pooled_text_embed(pooled_projections) - if guidance: temb += self.guidance_embed(guidance, dtype=hidden_states.dtype) @@ -398,9 +689,6 @@ def custom_forward(*inputs): interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -414,38 +702,23 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) + + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) else: - if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) # controlnet residual if controlnet_single_block_samples is not None: @@ -455,8 +728,6 @@ def custom_forward(*inputs): hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ controlnet_single_block_samples[index_block // interval_control] ) - if version.parse(diffusers_version) < version.parse("0.35.0.dev0"): - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 84571cc1f295..3d9f78cc348e 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -9,15 +9,19 @@ T5TokenizerFast, ) -from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FluxLoraLoaderMixin -from diffusers.models.transformers.transformer_bria import BriaTransformer2DModel -from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput -from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers -from diffusers.utils import ( +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers.transformer_bria import BriaTransformer2DModel +from ...pipelines.bria.pipeline_output import BriaPipelineOutput +from ...pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps +from ...schedulers import ( + DDIMScheduler, + EulerAncestralDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + KarrasDiffusionSchedulers, +) +from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, @@ -25,7 +29,7 @@ scale_lora_layers, unscale_lora_layers, ) -from diffusers.utils.torch_utils import randn_tensor +from ...utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -42,15 +46,23 @@ Examples: ```py >>> import torch - >>> from diffusers import StableDiffusion3Pipeline + >>> from diffusers import BriaPipeline - >>> pipe = StableDiffusion3Pipeline.from_pretrained( - ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 - ... ) + >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.float16) >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" + # BRIA's T5 text encoder is sensitive to precision. We need to cast it to float16 and keep the final layer in float32. + + >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float16) + >>> for block in pipe.text_encoder.encoder.block: + ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + # BRIA's VAE is not supported in mixed precision, so we use float32. + + >>> if pipe.vae.config.shift_factor == 0: + ... pipe.vae.to(dtype=torch.float32) + + >>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus." 
>>> image = pipe(prompt).images[0] - >>> image.save("sd3.png") + >>> image.save("bria.png") ``` """ @@ -75,78 +87,22 @@ def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): return new_sigmas -def get_t5_prompt_embeds( - tokenizer: T5TokenizerFast, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 128, - device: Optional[torch.device] = None, -): - device = device or text_encoder.device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - prompt_embeds_list = [] - for p in prompt: - text_inputs = tokenizer( - p, - # padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - - # Concat zeros to max_sequence - b, seq_len, dim = prompt_embeds.shape - if seq_len < max_sequence_length: - padding = torch.zeros( - (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device - ) - prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=0) - prompt_embeds = prompt_embeds.to(device=device) - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) - - return prompt_embeds - - -""" -Based on FluxPipeline with several changes: -- no pooled embeddings -- We use zero padding for prompts -- No guidance embedding since this is not a distilled version -""" - - class BriaPipeline(FluxPipeline): r""" + Based on FluxPipeline with several changes: + - no pooled embeddings + - We use zero padding for prompts + - No guidance embedding since this is not a distilled version + Args: - transformer ([`SD3Transformer2DModel`]): + transformer ([`BriaTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`T5EncoderModel`]): - Frozen text-encoder. Stable Diffusion 3 uses + Frozen text-encoder. Bria uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
tokenizer (`T5TokenizerFast`): @@ -248,14 +204,12 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds = get_t5_prompt_embeds( - self.tokenizer, - self.text_encoder, + prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, - ).to(dtype=self.transformer.dtype) + ) if do_classifier_free_guidance and negative_prompt_embeds is None: if not is_ng_none(negative_prompt): @@ -275,14 +229,12 @@ def encode_prompt( " the batch size of `prompt`." ) - negative_prompt_embeds = get_t5_prompt_embeds( - self.tokenizer, - self.text_encoder, + negative_prompt_embeds = self._get_t5_prompt_embeds( prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, - ).to(dtype=self.transformer.dtype) + ) else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) @@ -291,8 +243,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, negative_prompt_embeds, text_ids @@ -310,7 +261,7 @@ def do_classifier_free_guidance(self): @property def joint_attention_kwargs(self): - return self._joint_attention_kwargs + return self.attention_kwargs @property def num_timesteps(self): @@ -338,7 +289,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 128, @@ -393,8 +344,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -413,7 +363,7 @@ def __call__( Examples: Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ @@ -432,7 +382,7 @@ def __call__( ) self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs + self.attention_kwargs = attention_kwargs self._interrupt = False # 2. 
Define call parameters @@ -521,8 +471,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # Supprot different diffusers versions - if len(latent_image_ids.shape) == 3: latent_image_ids = latent_image_ids[0] if len(text_ids.shape) == 3: @@ -620,8 +568,6 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - # if height % 8 != 0 or width % 8 != 0: - # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" @@ -662,18 +608,60 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - def to(self, *args, **kwargs): - DiffusionPipeline.to(self, *args, **kwargs) - # T5 is senstive to precision so we use the precision used for precompute and cast as needed - if self.text_encoder is not None: - self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) - for block in self.text_encoder.encoder.block: - block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + ): + tokenizer = self.tokenizer + text_encoder = self.text_encoder + device = device or text_encoder.device - if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32: - self.vae.to(dtype=torch.float32) + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + prompt_embeds_list = [] + for p in prompt: + text_inputs = tokenizer( + p, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=0) + prompt_embeds = prompt_embeds.to(device=device) - return self + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + return prompt_embeds def prepare_latents( self, diff 
--git a/src/diffusers/pipelines/bria/pipeline_output.py b/src/diffusers/pipelines/bria/pipeline_output.py index 2cda68de292f..54eed0623371 100644 --- a/src/diffusers/pipelines/bria/pipeline_output.py +++ b/src/diffusers/pipelines/bria/pipeline_output.py @@ -10,7 +10,7 @@ @dataclass class BriaPipelineOutput(BaseOutput): """ - Output class for Stable Diffusion pipelines. + Output class for Bria pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py index d95e2950f91d..8a8d0dcecffc 100644 --- a/tests/models/transformers/test_models_transformer_bria.py +++ b/tests/models/transformers/test_models_transformer_bria.py @@ -28,7 +28,7 @@ enable_full_determinism() -def create_chroma_ip_adapter_state_dict(model): +def create_bria_ip_adapter_state_dict(model): # "ip_adapter" (cross-attention weights) ip_cross_attn_state_dict = {} key_id = 0 From d2fba0af486a29d4465c6a2a6816ff2d5553cc9f Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Sun, 10 Aug 2025 08:53:01 +0000 Subject: [PATCH 12/18] Refactor Bria model components and update documentation - Removed outdated inference example from Bria 3.2 documentation. - Introduced the BriaTransformerBlock class to enhance model architecture. - Updated attention handling to use `attention_kwargs` instead of `joint_attention_kwargs`. - Improved import structure in the Bria pipeline to handle optional dependencies. - Adjusted test cases to reflect changes in model dtype assertions. --- docs/source/en/api/pipelines/bria_3_2.md | 16 -- src/diffusers/models/embeddings.py | 1 + .../models/transformers/transformer_bria.py | 235 ++++++++---------- src/diffusers/pipelines/bria/__init__.py | 33 ++- src/diffusers/pipelines/bria/pipeline_bria.py | 28 +-- tests/pipelines/bria/test_pipeline_bria.py | 2 +- 6 files changed, 147 insertions(+), 168 deletions(-) diff --git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md index 19b53177549c..059fa01f9f83 100644 --- a/docs/source/en/api/pipelines/bria_3_2.md +++ b/docs/source/en/api/pipelines/bria_3_2.md @@ -35,22 +35,6 @@ Use the command below to log in: hf auth login ``` -## Inference - - -```python -import torch -from diffusers import BriaPipeline - -pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2",revision="bria_3_2_diffusers", torch_dtype=torch.bfloat16) -pipe.enable_model_cpu_offload() - -prompt = "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done." 
- -image = pipe(prompt).images[0] -image.save("bria.png") -``` - ## BriaPipeline diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 051a776e49fd..b51f5d7aec25 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1220,6 +1220,7 @@ def apply_rotary_emb( x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 6cfc17e2e247..27a9941501a1 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -103,7 +103,7 @@ def get_1d_rotary_pos_embed( ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: - # flux, hunyuan-dit, cogvideox + # bria freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin @@ -252,97 +252,12 @@ def forward( unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] if len(unused_kwargs) > 0: logger.warning( - f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) -@maybe_allow_in_graph -class BriaTransformerBlock(nn.Module): - def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 - ): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - self.norm1_context = AdaLayerNormZero(dim) - - self.attn = BriaAttention( - query_dim=dim, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, - processor=BriaAttnProcessor(), - eps=eps, - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - joint_attention_kwargs = joint_attention_kwargs or {} - - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output - - # Process attention outputs for the `encoder_hidden_states`. 
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - class BriaEmbedND(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): @@ -441,6 +356,91 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +@maybe_allow_in_graph +class BriaTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = BriaAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=BriaAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + attention_kwargs = attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + @maybe_allow_in_graph class BriaSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): @@ -471,7 +471,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: text_seq_len = encoder_hidden_states.shape[1] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -479,11 +479,11 @@ def forward( residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - joint_attention_kwargs = joint_attention_kwargs or {} + attention_kwargs = attention_kwargs or {} attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, + **attention_kwargs, ) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) @@ -584,13 +584,13 @@ def forward( img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, controlnet_block_samples=None, controlnet_single_block_samples=None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`BriaTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -603,7 +603,7 @@ def forward( Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -615,9 +615,9 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. 
""" - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 @@ -625,9 +625,9 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) hidden_states = self.x_embedder(hidden_states) @@ -654,25 +654,14 @@ def forward( image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, + attention_kwargs, ) else: @@ -690,26 +679,14 @@ def custom_forward(*inputs): hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - + if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - create_custom_forward(block), + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, + attention_kwargs, ) else: diff --git a/src/diffusers/pipelines/bria/__init__.py b/src/diffusers/pipelines/bria/__init__.py index 88e51b534ab0..60e319ac7910 100644 --- a/src/diffusers/pipelines/bria/__init__.py +++ b/src/diffusers/pipelines/bria/__init__.py @@ -2,16 +2,38 @@ from ...utils import ( DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, ) -_import_structure = { - "pipeline_bria": ["BriaPipeline"], -} +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria"] = 
["BriaPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .pipeline_bria import BriaPipeline + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria import BriaPipeline + else: import sys @@ -21,3 +43,6 @@ _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 3d9f78cc348e..1ba726a1220d 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -66,8 +66,6 @@ ``` """ -T5_PRECISION = torch.float16 - def is_ng_none(negative_prompt): return ( @@ -134,19 +132,12 @@ def __init__( feature_extractor=feature_extractor, ) - # TODO - why different than offical flux (-1) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k - # T5 is senstive to precision so we use the precision used for precompute and cast as needed - if self.text_encoder is not None: - self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION) - for block in self.text_encoder.encoder.block: - block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) - if self.vae.config.shift_factor is None: self.vae.config.shift_factor = 0 self.vae.to(dtype=torch.float32) @@ -260,8 +251,12 @@ def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property - def joint_attention_kwargs(self): - return self.attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs + + @attention_kwargs.setter + def attention_kwargs(self, value): + self._attention_kwargs = value @property def num_timesteps(self): @@ -345,7 +340,7 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -395,9 +390,7 @@ def __call__( device = self._execution_device - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt( prompt=prompt, @@ -432,8 +425,7 @@ def __call__( and self.scheduler.config["use_dynamic_shifting"] ): sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] # Shift by height - Why just height? 
- print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}") + image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -495,7 +487,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, + attention_kwargs=self.attention_kwargs, return_dict=False, txt_ids=text_ids, img_ids=latent_image_ids, diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py index 2ded184e2109..e6dec4ddc0b9 100644 --- a/tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/pipelines/bria/test_pipeline_bria.py @@ -211,7 +211,7 @@ def test_to_dtype(self): pipe.set_progress_bar_config(disable=None) model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] - self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False, True, True]) + self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [True, True, True]) def test_torch_dtype_dict(self): components = self.get_dummy_components() From d7dd8b39e479a68223caec3319301149791c2554 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Fri, 15 Aug 2025 03:07:05 +0000 Subject: [PATCH 13/18] Update Bria model reference in documentation to reflect new file naming convention --- docs/source/en/_toctree.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 03d56de2876e..1832f4baa115 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -326,7 +326,7 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - - local: api/models/bria_transformer2d + - local: api/models/transformer_bria title: BriaTransformer2DModel - local: api/models/chroma_transformer title: ChromaTransformer2DModel From 36b0133c4f52b905dfce1368863d2164d2c703e7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 18 Aug 2025 03:55:01 +0530 Subject: [PATCH 14/18] Update docs/source/en/_toctree.yml --- docs/source/en/_toctree.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 290b52f0967a..74097b882175 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -340,7 +340,7 @@ title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - - local: api/models/transformer_bria + - local: api/models/bria_transformer title: BriaTransformer2DModel - local: api/models/chroma_transformer title: ChromaTransformer2DModel From 9c6d9dd734a1453919a413eb4eceda6397732757 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Mon, 18 Aug 2025 15:14:04 +0000 Subject: [PATCH 15/18] Refactor BriaPipeline to inherit from DiffusionPipeline instead of FluxPipeline, updating imports accordingly. 
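To make the scope of this refactor concrete, the minimal sketch below illustrates what the change means for downstream code. It is an illustration rather than part of the patch; the checkpoint ID and dtype are taken from the pipeline's docstring example, and the helper import mirrors the relative import used inside `pipeline_bria.py`.

```python
import torch

from diffusers import BriaPipeline, DiffusionPipeline, FluxPipeline

# The Flux helpers the pipeline still relies on are plain function imports
# (the absolute form of the relative import used in pipeline_bria.py),
# no longer methods inherited from FluxPipeline.
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps  # noqa: F401

# Checkpoint ID and dtype follow the docstring example; loading requires the model weights.
pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)

# After this refactor BriaPipeline is a direct DiffusionPipeline subclass,
# so no Flux-specific behavior is picked up implicitly.
assert isinstance(pipe, DiffusionPipeline)
assert not isinstance(pipe, FluxPipeline)
```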
--- src/diffusers/pipelines/bria/pipeline_bria.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 1ba726a1220d..f3c0b166626f 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -13,8 +13,9 @@ from ...loaders import FluxLoraLoaderMixin from ...models import AutoencoderKL from ...models.transformers.transformer_bria import BriaTransformer2DModel +from ...pipelines import DiffusionPipeline from ...pipelines.bria.pipeline_output import BriaPipelineOutput -from ...pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps from ...schedulers import ( DDIMScheduler, EulerAncestralDiscreteScheduler, @@ -85,7 +86,7 @@ def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): return new_sigmas -class BriaPipeline(FluxPipeline): +class BriaPipeline(DiffusionPipeline): r""" Based on FluxPipeline with several changes: - no pooled embeddings From f70229290e9a6551fb35a2dd6be459f83b5b1673 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Tue, 19 Aug 2025 12:29:03 +0000 Subject: [PATCH 16/18] move the __call__ func to the end of file --- src/diffusers/pipelines/bria/pipeline_bria.py | 356 +++++++++--------- 1 file changed, 178 insertions(+), 178 deletions(-) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index f3c0b166626f..15948a46d545 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -267,6 +267,184 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + ): + tokenizer = self.tokenizer + text_encoder = self.text_encoder + device = device or text_encoder.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + prompt_embeds_list = [] + for p in prompt: + text_inputs = tokenizer( + p, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=0) + prompt_embeds = prompt_embeds.to(device=device) + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + return prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -549,181 +727,3 @@ def __call__( return (image,) return BriaPipelineOutput(images=image) - - def check_inputs( - self, - prompt, - height, - width, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: - logger.warning( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 128, - device: Optional[torch.device] = None, - ): - tokenizer = self.tokenizer - text_encoder = self.text_encoder - device = device or text_encoder.device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - prompt_embeds_list = [] - for p in prompt: - text_inputs = tokenizer( - p, - # padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - - # Concat zeros to max_sequence - b, seq_len, dim = prompt_embeds.shape - if seq_len < max_sequence_length: - padding = torch.zeros( - (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device - ) - prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=0) - prompt_embeds = prompt_embeds.to(device=device) - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) - prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) - return prompt_embeds - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. 
- height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids - - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) From 4430fe9422bf7f576fb5add1bf0d8ec62d01a74a Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Wed, 20 Aug 2025 08:37:22 +0000 Subject: [PATCH 17/18] Update BriaPipeline example to use bfloat16 for precision sensitivity for better result --- src/diffusers/pipelines/bria/pipeline_bria.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index 15948a46d545..c3d6ae1ac5df 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -49,11 +49,11 @@ >>> import torch >>> from diffusers import BriaPipeline - >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.float16) + >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - # BRIA's T5 text encoder is sensitive to precision. We need to cast it to float16 and keep the final layer in float32. + # BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32. 
- >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float16) + >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16) >>> for block in pipe.text_encoder.encoder.block: ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) # BRIA's VAE is not supported in mixed precision, so we use float32. @@ -267,6 +267,7 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + def check_inputs( self, prompt, From 3b847aba8541a893cb1a45088a01657daf9b7939 Mon Sep 17 00:00:00 2001 From: Gal Davidi Date: Wed, 20 Aug 2025 08:47:08 +0000 Subject: [PATCH 18/18] make style && make quality && make fix-copies --- src/diffusers/pipelines/bria/pipeline_bria.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py index c3d6ae1ac5df..39ed484793d5 100644 --- a/src/diffusers/pipelines/bria/pipeline_bria.py +++ b/src/diffusers/pipelines/bria/pipeline_bria.py @@ -267,7 +267,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - def check_inputs( self, prompt,
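A closing note on the latent layout these patches rely on: `_pack_latents` folds every 2x2 patch of the VAE latent into the channel dimension and flattens the spatial grid into a token sequence, and `_unpack_latents` reverses the operation. The sketch below mirrors those static methods as standalone functions; the shapes are illustrative and assume the pipeline's `vae_scale_factor` of 16 (an 8x VAE combined with the 2x packing).

```python
import torch


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 spatial patch becomes one sequence token.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)


def unpack_latents(latents, height, width, vae_scale_factor):
    # Inverse of pack_latents; height and width are the requested image size in pixels.
    batch_size, num_patches, channels = latents.shape
    height = height // vae_scale_factor
    width = width // vae_scale_factor
    latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)


# A 1024x1024 image maps to a 128x128 VAE latent grid; the channel count of 4 is illustrative.
x = torch.randn(1, 4, 128, 128)
packed = pack_latents(x, batch_size=1, num_channels_latents=4, height=128, width=128)  # (1, 4096, 16)
restored = unpack_latents(packed, height=1024, width=1024, vae_scale_factor=16)  # (1, 4, 128, 128)
assert torch.equal(x, restored)
```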