diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index dd70fb82fff4..c34415595507 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -41,7 +41,7 @@
     replace_example_docstring,
 )
 from ...utils.import_utils import is_transformers_version
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            device_mod = getattr(torch, device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            empty_device_cache(device.type)
 
         model_sequence = [
             self.text_encoder.text_model,
diff --git a/src/diffusers/pipelines/consisid/consisid_utils.py b/src/diffusers/pipelines/consisid/consisid_utils.py
index 23811a4986e3..521d4d787e54 100644
--- a/src/diffusers/pipelines/consisid/consisid_utils.py
+++ b/src/diffusers/pipelines/consisid/consisid_utils.py
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
     Parameters:
     - model_path: Path to the directory containing model files.
-    - device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
+    - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
     - dtype: Data type (e.g., torch.float32) for model inference.
 
     Returns:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index e14c51ab94ce..12aac7ff0d8a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -37,7 +37,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1339,7 +1339,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 2102b34d5ffb..1fdc285c26e9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -36,7 +36,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1311,7 +1311,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 3fc1206ba5f6..bff14b3fbeb5 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -38,7 +38,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1500,7 +1500,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 51ddd997142b..b38b60a20ae1 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -51,7 +51,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1858,7 +1858,7 @@ def denoising_value_valid(dnv):
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 8aa47fd4277a..d178f382db47 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1465,7 +1465,11 @@ def __call__(
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                if (
+                    torch.cuda.is_available()
+                    and (is_unet_compiled and is_controlnet_compiled)
+                    and is_torch_higher_equal_2_1
+                ):
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 52e76a220454..08c6d02f27f9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -53,7 +53,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -921,7 +921,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
@@ -1632,7 +1632,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index fc160573ba4e..a0b59dd0f851 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -51,7 +51,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -1766,7 +1766,7 @@ def denoising_value_valid(dnv):
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
index bd98b7975082..3bcaf228abf9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -53,7 +53,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -876,7 +876,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
@@ -1574,7 +1574,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index e780d6966a82..874b4531be78 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -36,7 +36,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -853,7 +853,7 @@ def __call__(
             for i, t in enumerate(timesteps):
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if is_controlnet_compiled and is_torch_higher_equal_2_1:
+                if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -902,7 +902,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index f6d445a4869e..25f90d58a423 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -193,7 +193,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -411,7 +411,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -652,7 +652,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 2ebd995eb58d..c104e4bcced0 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -179,7 +179,7 @@ def __init__(
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -407,7 +407,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -417,7 +417,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
         self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -656,7 +656,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index 4b0f8cdb62d3..93dcd36a7021 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -25,7 +25,7 @@
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import KolorsPipelineOutput
 from .text_encoder import ChatGLMModel
@@ -618,7 +618,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index d340738c86ca..e243100d1a73 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -35,7 +35,7 @@
     logging,
     replace_example_docstring,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
@@ -397,20 +397,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+        `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
+        lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
+        of the `unet`.
         """
         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        device_type = get_device()
+        device = torch.device(f"{device_type}:{gpu_id}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            empty_device_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         model_sequence = [
             self.text_encoder.text_model,
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index 205868762935..2d472382da41 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -36,7 +36,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1228,7 +1228,11 @@ def __call__(
             for i, t in enumerate(timesteps):
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                if (
+                    torch.cuda.is_available()
+                    and (is_unet_compiled and is_controlnet_compiled)
+                    and is_torch_higher_equal_2_1
+                ):
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
@@ -1309,7 +1313,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index f8a58a665475..35a62b3d474a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -37,7 +37,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1521,7 +1521,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index 259b8939ce99..dbabfe423d47 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -1498,7 +1498,11 @@ def __call__(
             for i, t in enumerate(timesteps):
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                if (
+                    torch.cuda.is_available()
+                    and (is_unet_compiled and is_controlnet_compiled)
+                    and is_torch_higher_equal_2_1
+                ):
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index dc20ea95cdbe..3b4e608153e3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -52,7 +52,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from .pag_utils import PAGMixin
@@ -926,7 +926,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
@@ -1648,7 +1648,7 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 02ede6c3d6c6..78320223491c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -35,7 +35,7 @@
     logging,
     replace_example_docstring,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ..pixart_alpha.pipeline_pixart_alpha import (
     ASPECT_RATIO_512_BIN,
@@ -917,9 +917,15 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
+            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+            oom_error = (
+                torch.OutOfMemoryError
+                if is_torch_version(">=", "2.5.0")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             try:
                 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except torch.cuda.OutOfMemoryError as e:
+            except oom_error as e:
                 warnings.warn(
                     f"{e}. \n"
                     f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 92eb45a72e7f..74009c581a42 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -49,7 +49,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from .pag_utils import PAGMixin
@@ -716,7 +716,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index efeb085a723b..4fb3a541f509 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -67,7 +67,7 @@
     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import get_device, is_compiled_module
+from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
 
 
 if is_torch_npu_available():
@@ -1201,9 +1201,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         self._offload_device = device
 
         self.to("cpu", silence_dtype_warnings=True)
-        device_mod = getattr(torch, device.type, None)
-        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+        empty_device_cache(device.type)
 
         all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
@@ -1313,10 +1311,9 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
         self._offload_device = device
 
         if self.device.type != "cpu":
+            orig_device_type = self.device.type
             self.to("cpu", silence_dtype_warnings=True)
-            device_mod = getattr(torch, self.device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            empty_device_cache(orig_device_type)
 
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 6c2c1fb0a8f1..8a7d574cd5df 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -38,7 +38,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..pixart_alpha.pipeline_pixart_alpha import (
     ASPECT_RATIO_512_BIN,
@@ -982,9 +982,15 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
+            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+            oom_error = (
+                torch.OutOfMemoryError
+                if is_torch_version(">=", "2.5.0")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             try:
                 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except torch.cuda.OutOfMemoryError as e:
+            except oom_error as e:
                 warnings.warn(
                     f"{e}. \n"
                     f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index 593e0895e43d..bf96fe902e3a 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -38,7 +38,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..pixart_alpha.pipeline_pixart_alpha import (
     ASPECT_RATIO_512_BIN,
@@ -1078,9 +1078,15 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
+            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+            oom_error = (
+                torch.OutOfMemoryError
+                if is_torch_version(">=", "2.5.0")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             try:
                 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except torch.cuda.OutOfMemoryError as e:
+            except oom_error as e:
                 warnings.warn(
                     f"{e}. \n"
                     f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 46edbf7c33ef..d13506cc7120 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -38,7 +38,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
 from .pipeline_output import SanaPipelineOutput
@@ -864,9 +864,15 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
+            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+            oom_error = (
+                torch.OutOfMemoryError
+                if is_torch_version(">=", "2.5.0")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             try:
                 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except torch.cuda.OutOfMemoryError as e:
+            except oom_error as e:
                 warnings.warn(
                     f"{e}. \n"
                     f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
index f71b980ffc84..99cc101f06d0 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
@@ -39,7 +39,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
 from .pipeline_output import SanaPipelineOutput
@@ -952,9 +952,15 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
+            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+            oom_error = (
+                torch.OutOfMemoryError
+                if is_torch_version(">=", "2.5.0")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             try:
                 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except torch.cuda.OutOfMemoryError as e:
+            except oom_error as e:
                 warnings.warn(
                     f"{e}. \n"
                     f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index a7c273fbe1e9..c387ab7b0501 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -125,7 +125,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -135,7 +135,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
         self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 568ae7f7d671..6c0221d2092a 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -53,6 +53,7 @@
 )
 from ...utils import is_accelerate_available, logging
 from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
+from ...utils.torch_utils import get_device
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
 from ..pipeline_utils import DiffusionPipeline
@@ -1272,7 +1273,7 @@ def download_from_original_stable_diffusion_ckpt(
             checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
         else:
             if device is None:
-                device = "cuda" if torch.cuda.is_available() else "cpu"
+                device = get_device()
                 checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
             else:
                 checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
@@ -1842,7 +1843,7 @@ def download_controlnet_from_original_ckpt(
                 checkpoint[key] = f.get_tensor(key)
     else:
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            device = get_device()
             checkpoint = torch.load(checkpoint_path, map_location=device)
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index dfbbdaeac5a3..d4248a14e560 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -50,7 +50,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import StableDiffusionXLPipelineOutput
@@ -704,7 +704,7 @@ def prepare_latents(
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         image = image.to(device=device, dtype=dtype)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 800cf7b65af9..96316f8e91e5 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -23,7 +23,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
 from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionSafetyChecker
@@ -760,7 +760,7 @@ def __call__(
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
-            torch.cuda.empty_cache()
+            empty_device_cache()
 
         if output_type == "latent":
             image = latents
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
index 0e7158ee6b79..829fe23bc0ee 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -113,7 +113,7 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -123,7 +123,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
         self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 7d0a6faa7afb..dae2a257d88c 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -22,6 +22,7 @@
 from packaging import version
 
 from .import_utils import is_peft_available, is_torch_available
+from .torch_utils import empty_device_cache
 
 if is_torch_available():
@@ -95,8 +96,7 @@ def recurse_remove_peft_layers(model):
                 setattr(model, name, new_module)
                 del module
 
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    empty_device_cache()
 
     return model
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index bb5674092d09..35bcd46ede5b 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -172,3 +172,10 @@ def get_device():
         return "xpu"
     else:
         return "cpu"
+
+
+def empty_device_cache(device_type: Optional[str] = None):
+    if device_type is None:
+        device_type = get_device()
+    device_mod = getattr(torch, device_type, torch.cuda)
+    device_mod.empty_cache()
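
Usage note (illustrative, not part of the patch): the new `empty_device_cache()` helper centralizes what the pipelines previously did inline with `torch.cuda.empty_cache()`. A minimal sketch of how it can be called from user code, assuming an accelerator backend (CUDA or XPU) is present; the checkpoint id is only a placeholder:

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.torch_utils import empty_device_cache, get_device

# get_device() reports the available accelerator backend ("cuda", "xpu", ...),
# falling back to "cpu" when none is found.
device = get_device()

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to(device)
image = pipe("an astronaut riding a horse").images[0]

# Device-agnostic replacement for torch.cuda.empty_cache(): with no argument,
# the helper resolves the backend via get_device() and empties that cache.
del pipe
empty_device_cache()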
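
The Sana and PAG-Sana hunks above select the out-of-memory exception in a device-agnostic way: PyTorch 2.5 introduced the backend-independent `torch.OutOfMemoryError`, while older releases only expose it on the backend module (for example `torch.cuda.OutOfMemoryError`). A standalone sketch of the same pattern, assuming an accelerator backend is available; `pick_oom_error` and `decode_with_tiling_fallback` are illustrative names, not diffusers API:

import torch

from diffusers.utils.torch_utils import get_device, is_torch_version


def pick_oom_error():
    # torch.OutOfMemoryError is backend-agnostic from PyTorch 2.5 on; earlier
    # releases only define it on the backend module (torch.cuda, torch.xpu, ...).
    torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
    return (
        torch.OutOfMemoryError
        if is_torch_version(">=", "2.5.0")
        else torch_accelerator_module.OutOfMemoryError
    )


def decode_with_tiling_fallback(vae, latents):
    # Retry a VAE decode with tiling enabled when the first attempt runs out of
    # accelerator memory, mirroring the advice in the pipelines' warning message.
    oom_error = pick_oom_error()
    try:
        return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    except oom_error:
        vae.enable_tiling()
        return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]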
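
The cudagraph hunks guard `torch._inductor.cudagraph_mark_step_begin()` behind `torch.cuda.is_available()`, since the step marker only matters when inductor captures CUDA graphs (e.g. with the "reduce-overhead" compile mode). A self-contained sketch of the guarded pattern, with a toy module standing in for the compiled UNet/ControlNet pair:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-in for the compiled UNet/ControlNet; "reduce-overhead" is the mode
# that makes inductor capture CUDA graphs on supported GPUs.
model = torch.compile(nn.Linear(64, 64).to(device), mode="reduce-overhead")
x = torch.randn(8, 64, device=device)

for step in range(3):
    # Mirror the patch: only mark a new cudagraph step when CUDA is available;
    # on XPU or CPU the call is skipped.
    if torch.cuda.is_available():
        torch._inductor.cudagraph_mark_step_begin()
    y = model(x)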