File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -2954,10 +2954,12 @@ def _get_and_verify_dtype(
2954
2954
) -> torch .dtype :
2955
2955
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
2956
2956
# because config.torch_dtype can be None.
2957
- config_dtype = getattr (config . get_text_config () , "torch_dtype" , None )
2957
+ config_dtype = getattr (config , "torch_dtype" , None )
2958
2958
2959
- # Fallback for multi-modal models if the root config
2959
+ # Fallbacks for multi-modal models if the root config
2960
2960
# does not define torch_dtype
2961
+ if config_dtype is None :
2962
+ config_dtype = getattr (config .get_text_config (), "torch_dtype" , None )
2961
2963
if config_dtype is None and hasattr (config , "vision_config" ):
2962
2964
config_dtype = getattr (config .vision_config , "torch_dtype" , None )
2963
2965
You can’t perform that action at this time.
0 commit comments