diff --git a/flux_pipeline.py b/flux_pipeline.py
index 9b84954..da766a5 100644
--- a/flux_pipeline.py
+++ b/flux_pipeline.py
@@ -26,6 +26,7 @@
 config.cache_size_limit = 10000000000
 ind_config.shape_padding = True
+config.suppress_errors = True
 
 import platform
 
 from loguru import logger
@@ -153,6 +154,7 @@ def load_lora(
         lora_path: Union[str, OrderedDict[str, torch.Tensor]],
         scale: float,
         name: Optional[str] = None,
+        silent=False
     ):
         """
         Loads a LoRA checkpoint into the Flux flow transformer.
@@ -165,16 +167,16 @@ def load_lora(
             scale (float): Scaling factor for the LoRA weights.
             name (str): Name of the LoRA checkpoint, optionally can be left as None, since it only acts as an identifier.
         """
-        self.model.load_lora(path=lora_path, scale=scale, name=name)
+        self.model.load_lora(path=lora_path, scale=scale, name=name, silent=silent)
 
-    def unload_lora(self, path_or_identifier: str):
+    def unload_lora(self, path_or_identifier: str, silent=False):
         """
         Unloads the LoRA checkpoint from the Flux flow transformer.
 
         Args:
             path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
         """
-        self.model.unload_lora(path_or_identifier=path_or_identifier)
+        self.model.unload_lora(path_or_identifier=path_or_identifier, silent=silent)
 
     @torch.inference_mode()
     def compile(self):
diff --git a/lora_loading.py b/lora_loading.py
index cfd8dd1..81f7188 100644
--- a/lora_loading.py
+++ b/lora_loading.py
@@ -181,9 +181,7 @@ def convert_diffusers_to_flux_transformer_checkpoint(
                 dtype = sample_component_A.dtype
                 device = sample_component_A.device
             else:
-                logger.info(
-                    f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
-                )
+                logger.info(f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}")
                 temp_dict[f"{component}"] = [None, None]
 
     if device is not None:
@@ -344,30 +342,26 @@ def convert_diffusers_to_flux_transformer_checkpoint(
         shape_qkv_a = None
         shape_qkv_b = None
         # Q, K, V, mlp
-        q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
-        q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
+        q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight", None)
+        q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight", None)
         if q_A is not None and q_B is not None:
             has_q = True
             shape_qkv_a = q_A.shape
             shape_qkv_b = q_B.shape
-        k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
-        k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
+        k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight", None)
+        k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight", None)
         if k_A is not None and k_B is not None:
             has_k = True
             shape_qkv_a = k_A.shape
             shape_qkv_b = k_B.shape
-        v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
-        v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
+        v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight", None)
+        v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight", None)
         if v_A is not None and v_B is not None:
             has_v = True
             shape_qkv_a = v_A.shape
             shape_qkv_b = v_B.shape
-        mlp_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}proj_mlp.lora_A.weight"
-        )
-        mlp_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}proj_mlp.lora_B.weight"
-        )
+        mlp_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_A.weight", None)
+        mlp_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_B.weight", None)
         if mlp_A is not None and mlp_B is not None:
             has_mlp = True
             shape_qkv_a = mlp_A.shape
@@ -637,6 +631,7 @@ def apply_lora_to_model(
     lora_path: str | StateDict,
     lora_scale: float = 1.0,
     return_lora_resolved: bool = False,
+    silent=False
 ) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
@@ -675,7 +670,7 @@ def apply_lora_to_model(
             ]
         )
     )
-    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
+    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab), disable=silent):
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
@@ -697,6 +692,7 @@ def remove_lora_from_module(
     model: Flux,
     lora_path: str | StateDict,
    lora_scale: float = 1.0,
+    silent=False
 ):
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
@@ -737,7 +733,7 @@ def remove_lora_from_module(
         )
     )
 
-    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
+    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab), disable=silent):
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
diff --git a/modules/flux_model.py b/modules/flux_model.py
index 0da932b..f924a76 100644
--- a/modules/flux_model.py
+++ b/modules/flux_model.py
@@ -628,7 +628,7 @@ def has_lora(self, identifier: str):
             if lora.path == identifier or lora.name == identifier:
                 return True
 
-    def load_lora(self, path: str, scale: float, name: str = None):
+    def load_lora(self, path: str, scale: float, name: str = None, silent=False):
         from lora_loading import (
             LoraWeights,
             apply_lora_to_model,
@@ -642,23 +642,23 @@ def load_lora(self, path: str, scale: float, name: str = None):
                     f"Lora {lora.name} already loaded with same scale - ignoring!"
                 )
             else:
-                remove_lora_from_module(self, lora, lora.scale)
-                apply_lora_to_model(self, lora, scale)
+                remove_lora_from_module(self, lora, lora.scale, silent=silent)
+                apply_lora_to_model(self, lora, scale, silent=silent)
                 for idx, lora_ in enumerate(self.loras):
                     if lora_.path == lora.path:
                         self.loras[idx].scale = scale
                         break
         else:
-            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
+            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True, silent=silent)
             self.loras.append(LoraWeights(lora, path, name, scale))
 
-    def unload_lora(self, path_or_identifier: str):
+    def unload_lora(self, path_or_identifier: str, silent=False):
         from lora_loading import remove_lora_from_module
 
         removed = False
         for idx, lora_ in enumerate(list(self.loras)):
             if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
-                remove_lora_from_module(self, lora_.weights, lora_.scale)
+                remove_lora_from_module(self, lora_.weights, lora_.scale, silent=silent)
                 self.loras.pop(idx)
                 removed = True
                 break
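Usage sketch (not part of the patch): how the new `silent` flag added above would be called through `FluxPipeline`. The config and LoRA paths below are placeholders, and `load_pipeline_from_config_path` is assumed to be the repo's existing pipeline constructor.

```python
# Minimal sketch, assuming a working FluxPipeline setup from this repo.
# Paths are placeholders, not part of this change.
from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev.json")

# silent=True is forwarded down to apply_lora_to_model, which passes
# disable=silent to tqdm, so the "Applying LoRA" progress bar is not printed.
pipe.load_lora("loras/my_lora.safetensors", scale=0.8, name="my-lora", silent=True)

# Unload by the name given above (or by path); the "Unfusing LoRA" bar is
# likewise suppressed via remove_lora_from_module(..., silent=True).
pipe.unload_lora("my-lora", silent=True)
```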