
Loading a pipeline in the precision it was saved in #9797

Open
@mvafin

Description


Is your feature request related to a problem? Please describe.
Currently, if torch_dtype is not specified, the pipeline loads in float32 by default. When a model was saved in lower precision (float16 or bfloat16), its weights are therefore upcast to float32, doubling the memory the weights occupy. In scenarios where memory efficiency is critical (e.g., when exporting the model to another format), the model should be loaded in the original precision stored in its safetensors files. There is also currently no way to determine which dtype the model was saved in.
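The safetensors format stores each tensor's dtype in a small JSON header at the start of the file, so the saved precision can be read without loading any weights into memory. A minimal sketch, assuming a local .safetensors file; `read_safetensors_dtypes` is a hypothetical helper, not part of diffusers or safetensors:

```python
import json
import struct
from collections import Counter


def read_safetensors_dtypes(path):
    """Count tensors per dtype code ("F16", "BF16", "F32", ...) in a .safetensors file.

    Only the header is read: an 8-byte little-endian unsigned header length,
    followed by that many bytes of UTF-8 JSON. No tensor data is loaded.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    counts = Counter()
    for name, info in header.items():
        if name == "__metadata__":  # optional string metadata, not a tensor entry
            continue
        counts[info["dtype"]] += 1
    return counts
```

Calling it on a component's weights file returns a Counter keyed by safetensors dtype codes such as "F16", "BF16" or "F32".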

Describe the solution you'd like.
A feature similar to torch_dtype="auto" in the transformers library would be helpful. In transformers, this option loads a model in the dtype defined in its configuration. However, diffusers pipeline components generally do not record a dtype in their configs. The torch_dtype field in the text_encoder config can sometimes be used, but not every pipeline has it, and it is unclear whether it reliably reflects the precision of the other components.
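An "auto" option would need a rule for turning per-tensor dtypes into a single torch_dtype. One possible rule, sketched below as a hypothetical helper (not an existing diffusers or transformers API): ignore non-floating-point tensors such as integer buffers and pick the most common floating-point dtype. Counting tensors rather than parameters is a simplification.

```python
from collections import Counter

# Safetensors dtype codes mapped to torch dtype attribute names.
# Only floating-point codes are considered; integer or bool buffers
# say nothing about the precision the weights were saved in.
_FLOAT_CODES = {
    "F16": "float16",
    "BF16": "bfloat16",
    "F32": "float32",
    "F64": "float64",
}


def resolve_auto_dtype(dtype_counts):
    """Hypothetical "auto" rule: return the torch dtype name of the most common
    floating-point dtype, or None if there are no floating-point tensors."""
    floats = Counter()
    for code, n in dtype_counts.items():
        if code in _FLOAT_CODES:
            floats[_FLOAT_CODES[code]] += n
    if not floats:
        return None
    return floats.most_common(1)[0][0]
```

The returned name could be turned into a dtype with getattr(torch, name) and passed as torch_dtype to from_pretrained. Taking the majority tolerates checkpoints that keep a few tensors in float32; ties are broken arbitrarily.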

Describe alternatives you've considered.
One possible solution is a method that identifies the model's precision before from_pretrained is called. Currently the weights only become accessible after from_pretrained has downloaded them internally, so they cannot be inspected from outside beforehand. Such a method would let users choose the appropriate torch_dtype for loading the model.
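The pre-loading check described in this alternative could look like the sketch below. It assumes the pipeline has already been downloaded to a local directory (for example with huggingface_hub.snapshot_download) and that components store their weights as .safetensors files in per-component subfolders such as unet or vae. `scan_pipeline_dtypes` is hypothetical; it reads only the file headers.

```python
import json
import struct
from collections import Counter
from pathlib import Path


def _header_dtypes(path):
    # safetensors layout: 8-byte little-endian header length, then a JSON header.
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))
    return Counter(v["dtype"] for k, v in header.items() if k != "__metadata__")


def scan_pipeline_dtypes(pipeline_dir):
    """Return {component_name: Counter of dtype codes} for every subfolder
    that contains .safetensors files (sharded files are summed together)."""
    result = {}
    for sub in sorted(Path(pipeline_dir).iterdir()):
        if not sub.is_dir():
            continue  # e.g. model_index.json at the pipeline root
        counts = Counter()
        for file in sorted(sub.glob("*.safetensors")):
            counts += _header_dtypes(file)
        if counts:
            result[sub.name] = counts
    return result
```

Because every .safetensors file in a subfolder is merged, repositories that ship default and variant weights (e.g. *.fp16.safetensors) side by side would mix dtypes; filtering filenames by variant would be needed in that case.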

Additional context.
This feature is relevant to optimum-cli use cases where model conversion or export to other formats must work within memory constraints. If there’s already a way to achieve this, guidance would be appreciated.

Metadata


Labels: enhancement (New feature or request)
