diff --git a/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py b/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py
index c945e48df..d959fbe68 100644
--- a/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py
+++ b/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py
@@ -256,7 +256,11 @@ def get_parammeter_from_model(model, name_list):
         name_list = name_list[1:]
     model_or_param = model
     for attr in name_list:
-        model_or_param = getattr(model_or_param, attr)
+        model_or_param_new = getattr(model_or_param, attr, None)
+        if model_or_param_new is None:
+            if getattr(model_or_param, "quant_state", None) is not None:
+                model_or_param_new = getattr(model_or_param.quant_state, attr, None)
+        model_or_param = model_or_param_new
     return model_or_param
 
 def to_public_fp32(model, state_dict, params_attr):
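
The hunk changes the attribute walk in get_parammeter_from_model: when a name in name_list is missing on the current object, it is looked up on that object's quant_state (if present) instead of raising AttributeError. Below is a minimal, self-contained sketch of that fallback, not the patched file itself; the SimpleNamespace stand-ins and the names quantized_weight, linear, and absmax are illustrative assumptions, not part of the patch.

    from types import SimpleNamespace

    # Hypothetical stand-ins: a quantized "parameter" whose metadata (e.g. absmax)
    # lives on a quant_state object rather than on the parameter itself.
    quantized_weight = SimpleNamespace(
        data=[0, 1, 2],
        quant_state=SimpleNamespace(absmax=127.0),
    )
    model = SimpleNamespace(linear=SimpleNamespace(weight=quantized_weight))


    def resolve(model, name_list):
        # Same loop body as the patched get_parammeter_from_model: walk the dotted
        # name, falling back to quant_state when an attribute is missing.
        model_or_param = model
        for attr in name_list:
            model_or_param_new = getattr(model_or_param, attr, None)
            if model_or_param_new is None:
                if getattr(model_or_param, "quant_state", None) is not None:
                    model_or_param_new = getattr(model_or_param.quant_state, attr, None)
            model_or_param = model_or_param_new
        return model_or_param


    print(resolve(model, ["linear", "weight"]) is quantized_weight)  # True
    print(resolve(model, ["linear", "weight", "absmax"]))            # 127.0, found via quant_state

With the pre-patch loop the second lookup would raise AttributeError, since absmax is not an attribute of the weight object; with the patched loop it resolves through quant_state, and a name absent from both places now yields None rather than an exception.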