Fix TorchAOConfig skip layers #19265

Open · wants to merge 7 commits into main
15 changes: 15 additions & 0 deletions tests/quantization/test_torchao.py
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
        print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
    torch._dynamo.reset()
    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
    with vllm_runner(model_name=model_name,
                     quantization="torchao",
                     dtype="bfloat16",
                     pt_load_map_location="cuda:0") as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        assert output
        print(output)


if __name__ == "__main__":
    pytest.main([__file__])
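For reference, the new test can be run on its own with a standard pytest invocation (this assumes a CUDA device and a torchao install, per the skipif marker above):

    pytest tests/quantization/test_torchao.py -k test_qwenvl_int8wo_model_loading_with_params
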
58 changes: 51 additions & 7 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -17,11 +17,30 @@
logger = init_logger(__name__)


def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    """
    Robust skipping logic:
    should_skip("model.model.layers.1.q_proj",
                ["model.model.layers.1.q_proj"]) -> True
    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
    """
    for s in skip_modules:
        if prefix == s:
            return True
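        # Wrapping both strings in dots means a skip entry only matches whole
        # dot-delimited segments: "layers.1" matches "model.layers.1.q_proj"
        # but not "model.layers.11.q_proj".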
if f".{s}." in f".{prefix}.":
return True
return False


class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config
    def __init__(self,
                 torchao_config,
                 skip_modules: Optional[list[str]] = None) -> None:
        """
        # TorchAO quantization relies on tensor subclasses. In order,
        # to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ def __init__(self, torchao_config) -> None:
        os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
        logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
        """
        self.torchao_config = torchao_config
        self.skip_modules = skip_modules or []

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        assert len(hf_config) == 1 and "default" in hf_config, (
            "Expected only one key 'default' in quant_type dictionary")
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

        # Adds skipped modules defined in "modules_to_not_convert"
        skip_modules = config.get("modules_to_not_convert", []) or []

        # Adds skipped modules defined in "module_fqn_to_config"
        _data = quant_type.get("_data", {})
        if not isinstance(_data, dict):
            _data = {}

        module_fqn = _data.get("module_fqn_to_config", {})
        if not isinstance(module_fqn, dict):
            module_fqn = {}

        for layer, layer_cfg in module_fqn.items():
            if layer_cfg is None:
                skip_modules.append(layer)

        return cls(ao_config, skip_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
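For context, a rough sketch of the kind of Hugging Face quantization config that the parsing above targets (the values are illustrative, not taken from a real checkpoint; only the fields read by the code above are shown):

    quantization_config = {
        "quant_method": "torchao",
        "modules_to_not_convert": ["lm_head"],   # explicit skip list
        "quant_type": {
            "default": {
                "_data": {
                    "module_fqn_to_config": {
                        "_default": {},   # serialized config for all other modules (omitted)
                        "visual": None,   # None: leave this module unquantized
                    },
                },
            },
        },
    }
    # With the logic above, skip_modules ends up as ["lm_head", "visual"], and
    # should_skip("visual.blocks.0.attn.qkv", skip_modules) returns True.
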
@@ -80,13 +118,16 @@ def get_quant_method(self, layer: torch.nn.Module,

        from torchao.quantization import ModuleFqnToConfig

        if should_skip(prefix, self.skip_modules):
            return UnquantizedLinearMethod()

        module_fqn = prefix
        if isinstance(self.torchao_config, ModuleFqnToConfig):
            module_fqn_to_config = self.torchao_config.module_fqn_to_config
            c = module_fqn_to_config.get(
                module_fqn) or module_fqn_to_config.get("_default", None)
            if c is not None:
                current_torchao_config = TorchAOConfig(c)
                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                return TorchAOLinearMethod(current_torchao_config)
            else:
                return UnquantizedLinearMethod()
@@ -108,8 +149,11 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_

assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear = torch.nn.Linear(1, 1, bias=False)
Contributor

I am a little surprised the returned subclass actually has the right metadata for copy_-ing the state dict.

Author

Yes, as long as you set the right in_features / out_features. I have been using this trick for a long time and it saves a lot of time, especially with large layers; it can make loading roughly 5-10x faster. I also use it in my own vLLM loading logic.
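For illustration only (not part of this diff), a minimal standalone sketch of the trick described in this reply; the Int8WeightOnlyConfig name is an assumption about the installed torchao version:

    import torch
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    # Pretend this is a large checkpoint weight that has already been loaded.
    param = torch.nn.Parameter(torch.randn(4096, 4096, dtype=torch.bfloat16),
                               requires_grad=False)

    # Build a tiny 1x1 Linear so no full-size dummy weight is ever allocated,
    # then patch its metadata and weight to point at the real tensor.
    dummy_linear = torch.nn.Linear(1, 1, bias=False)
    dummy_linear.in_features = param.shape[1]
    dummy_linear.out_features = param.shape[0]
    dummy_linear.weight = param

    quantize_(dummy_linear, Int8WeightOnlyConfig())
    quantized_weight = dummy_linear.weight  # torchao tensor subclass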

    dummy_linear.in_features = param.shape[1]
    dummy_linear.out_features = param.shape[0]
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight