Fix autoround CI with amp #2253

Merged
merged 11 commits on Jul 24, 2025
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -600,6 +600,7 @@ def autoround_quantize_entry(
}
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
amp = quant_config.amp
lr_scheduler = quant_config.lr_scheduler
enable_quanted_input = quant_config.enable_quanted_input
enable_minmax_tuning = quant_config.enable_minmax_tuning
@@ -636,6 +637,7 @@ def autoround_quantize_entry(
quant_config=weight_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
amp=amp,
lr_scheduler=lr_scheduler,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
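The entry point now reads `amp` off the resolved config and forwards it to the quantizer, which decides whether block-wise tuning runs under mixed precision. The snippet below is only a sketch of that general pattern with a toy block and `torch.autocast`; it is not the AutoRound internals, and `tune_block` is a made-up helper for illustration.

```python
import torch


def tune_block(block: torch.nn.Module, inputs: torch.Tensor, amp: bool = True) -> torch.Tensor:
    """Illustrative only: run a block forward with or without mixed precision."""
    device_type = "cuda" if inputs.is_cuda else "cpu"
    # With amp=False (as the CPU tests in this PR request), autocast is disabled
    # and everything stays in fp32, avoiding bf16-related accuracy flakiness in CI.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=amp):
        return block(inputs)


block = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)
print(tune_block(block, x, amp=False).dtype)  # torch.float32
print(tune_block(block, x, amp=True).dtype)   # torch.bfloat16 on hardware with bf16 support
```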
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -948,6 +948,7 @@ def __init__(
act_dtype: Optional[str] = "int",
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
lr_scheduler=None,
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
@@ -995,6 +996,7 @@ def __init__(
act_dtype (Optional[str]): Data type for activation quantization. Default is "int".
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
lr_scheduler: The learning rate scheduler to be used.
enable_quanted_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
@@ -1042,6 +1044,7 @@ def __init__(
self.act_dtype = act_dtype
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.amp = amp
self.lr_scheduler = lr_scheduler
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
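With the config change in place, `amp` is passed like any other `AutoRoundConfig` keyword and the prepare/convert flow is unchanged. A minimal usage sketch mirroring the updated tests below; the checkpoint name is a placeholder and the commented-out calibration call stands in for whatever loop you already use:

```python
from transformers import AutoModelForCausalLM

from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Placeholder checkpoint; any small causal LM works for a smoke test.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

# amp=False keeps AutoRound tuning in fp32, matching the CPU CI configuration in the tests.
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")

model = prepare(model=model, quant_config=quant_config)
# run_fn(model, dataloader)  # feed a few calibration batches here, as the tests do
q_model = convert(model)
```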
32 changes: 16 additions & 16 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -69,7 +69,6 @@ def run_fn(model, dataloader):
else:
model(data)

@pytest.mark.skip(reason="SW-217321 pytorch inductor error")
@pytest.mark.skipif(is_habana_framework_installed(), reason="These tests are not supported on HPU for now.")
@pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
class TestAutoRoundCPU:
@@ -97,7 +96,7 @@ def setup_method(self, method):
@pytest.mark.parametrize("quant_lm_head", [True, False])
def test_autoround(self, quant_lm_head):
fp32_model = copy.deepcopy(self.gptj)
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
if quant_lm_head is False:
quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
logger.info(f"Test AutoRound with config {quant_config}")
@@ -110,15 +109,15 @@ def test_autoround(self, quant_lm_head):
out = q_model(self.inp)[0]
assert torch.allclose(out, self.label, atol=1e-1)
assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
assert "scale_dtype" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."
if quant_lm_head is True:
assert isinstance(q_model.lm_head, WeightOnlyLinear), "quantization for lm_head failed."

def test_int4_dtype(self):
fp32_model = copy.deepcopy(self.gptj)
quant_config = AutoRoundConfig(dtype="int4", nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
quant_config = AutoRoundConfig(dtype="int4", nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32")
logger.info(f"Test AutoRound with config {quant_config}")

# prepare + convert API
@@ -129,14 +128,14 @@ def test_int4_dtype(self):
out = q_model(self.inp)[0]
assert torch.allclose(out, self.label, atol=1e-1)
assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
assert "scale_dtype" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."

def test_autoround_with_quantize_API(self):
gpt_j_model = copy.deepcopy(self.gptj)

quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))

logger.info(f"Test AutoRound with config {quant_config}")
@@ -156,7 +155,7 @@ def test_save_and_load(self):
fp32_model = copy.deepcopy(self.gptj)
# known issue: scale_dtype="fp32" will cause accuracy gap between quantized model
# (using auto-round WeightOnlyLinear) and reloaded model (using INCWeightOnlyLinear)
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp16")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16")
# quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
logger.info(f"Test AutoRound with config {quant_config}")

@@ -185,11 +184,11 @@ def test_conv1d(self):
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2", use_cache=False)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
out1 = model(**encoded_input)[0]
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
model = prepare(model=model, quant_config=quant_config)
run_fn(model, self.dataloader)
q_model = convert(model)
@@ -207,7 +206,7 @@ def test_utils(self):
fp32_model = copy.deepcopy(self.gptj)
to_quant_block_names = get_multimodal_block_names(fp32_model, quant_vision=True)
quant_config = AutoRoundConfig(
nsamples=32, seqlen=10, iters=10, scale_dtype="fp16", to_quant_block_names=to_quant_block_names
nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16", to_quant_block_names=to_quant_block_names
)
logger.info(f"Test AutoRound with config {quant_config}")
device = detect_device("auto")
@@ -222,6 +221,7 @@ def test_utils(self):
assert torch.allclose(out, self.label, atol=1e-1)
assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."

@pytest.mark.skipif(Version(auto_round.__version__) <= Version("0.5.1"), reason="visual layer_name not processed.")
def test_mllm(self):
input = torch.randn(1, 32)
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
@@ -237,7 +237,7 @@ def test_mllm(self):
model=model,
tokenizer=tokenizer,
image_processor=None,
dataset="liuhaotian/llava_conv_58k",
dataset="NeelNanda/pile-10k",
extra_data_dir=None,
seqlen=32,
batch_size=1,
@@ -266,13 +266,13 @@ def test_mllm(self):
model = prepare(model=model, quant_config=quant_config)
run_fn(model, dataloader)
q_model = convert(model)
assert isinstance(q_model.model.layers[0].mlp.up_proj, WeightOnlyLinear), "model quantization failed."
assert isinstance(q_model.language_model.layers[0].mlp.up_proj, WeightOnlyLinear), "model quantization failed."

# def test_autoround_format_export(self):
# from neural_compressor.torch.quantization import load
# from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
# gpt_j_model = copy.deepcopy(self.gptj)
# quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq")
# quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32", export_format="auto_round:gptq")
# logger.info(f"Test AutoRound with config {quant_config}")
# model = prepare(model=gpt_j_model, quant_config=quant_config)
# run_fn(model, self.dataloader)
@@ -366,7 +366,7 @@ def test_autoround_w4a8(self):
@pytest.mark.parametrize("quant_lm_head", [True, False])
def test_autoround(self, quant_lm_head):
fp32_model = copy.deepcopy(self.tiny_llama_model)
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False, scale_dtype="fp32")
if quant_lm_head is False:
quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
logger.info(f"Test AutoRound with config {quant_config}")
@@ -386,7 +386,7 @@ def test_autoround(self, quant_lm_head):
def test_int4_dtype(self):
fp32_model = copy.deepcopy(self.tiny_llama_model)
quant_config = AutoRoundConfig(
dtype="int4", nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32"
dtype="int4", nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False ,scale_dtype="fp32"
)
logger.info(f"Test AutoRound with config {quant_config}")

@@ -402,7 +402,7 @@ def test_int4_dtype(self):
def test_autoround_with_quantize_API(self):
model = copy.deepcopy(self.tiny_llama_model)

quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32")
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False, scale_dtype="fp32")
quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))

logger.info(f"Test AutoRound with config {quant_config}")
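For reference, the version gate added to `test_mllm` follows the usual packaging-based pattern. A stripped-down sketch, assuming `Version` comes from `packaging.version` (which is how such checks are typically written); the test body here is a placeholder:

```python
import auto_round
import pytest
from packaging.version import Version


@pytest.mark.skipif(
    Version(auto_round.__version__) <= Version("0.5.1"),
    reason="visual layer_name not processed.",
)
def test_runs_only_on_recent_auto_round():
    # Placeholder body: the real test_mllm quantizes a Qwen2-VL model and checks
    # that its language_model layers are packed into WeightOnlyLinear.
    assert Version(auto_round.__version__) > Version("0.5.1")
```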