Highlights
We are excited to announce the 0.13.0 release of torchao! This release brings numerous QAT improvements, faster MXFP8 pretraining, and more!
Simpler Multi-step QAT API (#2629)
We added a new, simpler, multi-step QAT API that uses only a single config. Users can now specify the target post-training quantization (PTQ) config as the base config, and torchao will automatically infer the correct fake quantize configs to use! The prepare step inserts fake quantization matching the numerics of the base config, and the convert step applies the base config as real post-training quantization.
from torchao.quantization import (
quantize_,
Int8DynamicActivationInt4WeightConfig
)
from torchao.quantization.qat import QATConfig
# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)
# train (not shown)
# convert
quantize_(m, QATConfig(base_config, step="convert"))
For more advanced use cases, users can continue to specify specific FakeQuantizeConfigs as before:
import torch
from torchao.quantization.qat import IntxFakeQuantizeConfig

# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
activation_config=activation_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)
# train and convert (not shown)
(Prototype) NVFP4 and FP8 QAT (#2735, #2666)
We generalized QAT to support FP8 and NVFP4 use cases. You can try them out as follows:
from torchao.quantization import (
quantize_,
Float8DynamicActivationInt4WeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
)
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization.qat import QATConfig
# Pick a base config
base_config = Float8DynamicActivationInt4WeightConfig()   # or
base_config = Float8DynamicActivationFloat8WeightConfig() # or
base_config = Float8WeightOnlyConfig()                    # or
base_config = NVFP4InferenceConfig()
# prepare
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)
# train (not shown)
# convert
quantize_(m, QATConfig(base_config, step="convert"))
Users can also use the more specific FakeQuantizeConfigs for more advanced use cases, e.g.:
import torch
from torchao.quantization import PerRow
from torchao.quantization.qat import Float8FakeQuantizeConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig
act_config = Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
weight_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
# prepare
qat_config = QATConfig(
activation_config=act_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)
# train and convert (not shown)
(prototype) 1.2x MXFP8 dense pretraining speedups with torchtitan
We landed performance improvements (such as a faster to_mx dim1 cast) in our prototype MXFP8 training APIs, and we now achieve a 1.2x speedup vs bf16 when pretraining Llama 3 8B on NVIDIA B200. Please see our training benchmarks README for more information.
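For reference, the prototype MX training flow is driven by quantize_ with an MX linear config; a rough sketch (based on the mx_formats README, so exact import paths, parameter names, and kernel choices may differ across versions) looks like this:
import torch
from torchao.quantization import quantize_
# prototype API; names below assume the current mx_formats README and may change
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).cuda().to(torch.bfloat16)

# mxfp8 training: fp8 elements with block-size-32 scales, cuBLAS mxfp8 gemm on B200
# (EMULATED is assumed to be available as a fallback on older GPUs)
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
quantize_(m, config)
# train m as usual; the linear forward/backward matmuls now run in mxfp8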
torchao float8 training now integrated into axolotl!
You can now use torchao.float8 directly from axolotl to achieve end-to-end finetuning throughput (QPS) speedups of up to 1.1x on 3B parameter models (docs, release notes).
BC Breaking
Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig version bump to 2 (#2650)
We updated the underlying float8 tensor implementation, which bumps the default version of these two configs from 1 to 2. Loading a checkpoint quantized with version 1 now emits a deprecation warning:
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="bfloat16",
device_map="cuda",
)
/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details
warnings.warn(
Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
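For example, regenerating a checkpoint through the Hugging Face transformers integration might look roughly like this (a sketch: the base model name and output path are illustrative, and it assumes a recent transformers version that accepts torchao configs via TorchAoConfig):
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# quantize the original high-precision model with torchao >= 0.13,
# so the saved checkpoint carries config version 2
quant_config = TorchAoConfig(
    quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative base model
    torch_dtype="bfloat16",
    device_map="cuda",
    quantization_config=quant_config,
)
model.save_pretrained("opt-125m-fp8-v2", safe_serialization=False)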
Or download the checkpoint again (please let us know if the checkpoint is not updated)
Please see #2649 for more details around the deprecation.
QAT API Changes (#2628, #2641)
At a high level, the following existing APIs are deprecated and replaced by the new ones below. Although this is technically BC-breaking due to typing changes, it will not affect most users, since the old classes are kept around for now. However, they are planned to be removed in the next release.
IntXQuantizationAwareTrainingConfig -> QATConfig
FromIntXQuantizationAwareTrainingConfig -> QATConfig
FakeQuantizeConfig -> IntxFakeQuantizeConfig
FakeQuantizer -> IntxFakeQuantizer
Please see #2630 and the latest QAT README for more information on how to migrate.
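As an illustration, a typical int8/int4 QAT flow migrates roughly as follows (a sketch: the deprecated calls are shown in comments, and model is assumed to be a torch.nn.Module):
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Old (deprecated):
#   activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
#   weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
#   quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))
#   ... train ...
#   quantize_(model, FromIntXQuantizationAwareTrainingConfig())

# New: renamed fake quantize configs plus a single QATConfig for both steps
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
))
# ... train, then remove fake quantization (optionally pass a base config to also quantize) ...
quantize_(model, QATConfig(step="convert"))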
Remove old change_linear_weights_to_* APIs (#2721)
The following old quantization APIs no longer work and are removed:
change_linear_weights_to_int8_dqtensors(model)
change_linear_weights_to_int8_woqtensors(model)
change_linear_weights_to_int4_woqtensors(model)
Please use the quantize_ API with the following configs instead:
quantize_(model, Int8WeightOnlyConfig())
quantize_(model, Int4WeightOnlyConfig())
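With imports, the replacements look roughly like this (Int8DynamicActivationInt8WeightConfig is shown here as the dynamic-quantization counterpart of the removed change_linear_weights_to_int8_dqtensors, and model is assumed to be a torch.nn.Module):
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    Int4WeightOnlyConfig,
)

# was: change_linear_weights_to_int8_dqtensors(model)
quantize_(model, Int8DynamicActivationInt8WeightConfig())
# was: change_linear_weights_to_int8_woqtensors(model)
quantize_(model, Int8WeightOnlyConfig())
# was: change_linear_weights_to_int4_woqtensors(model)
quantize_(model, Int4WeightOnlyConfig())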
Deprecations
Deprecate old TORCH_VERSION variables (#2719)
The following variables are deprecated and will be removed in the next release:
TORCH_VERSION_AT_LEAST_2_2
TORCH_VERSION_AT_LEAST_2_3
TORCH_VERSION_AT_LEAST_2_4
TORCH_VERSION_AT_LEAST_2_5
TORCH_VERSION_AT_LEAST_2_6
TORCH_VERSION_AT_LEAST_2_7
TORCH_VERSION_AT_LEAST_2_8
TORCH_VERSION_AFTER_2_2
TORCH_VERSION_AFTER_2_3
TORCH_VERSION_AFTER_2_4
TORCH_VERSION_AFTER_2_5
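If your code relies on these flags, one possible replacement in downstream code is an explicit version comparison (a sketch using the packaging library, not a torchao-provided helper):
import torch
from packaging.version import Version

# replaces e.g. TORCH_VERSION_AT_LEAST_2_6; strip local suffixes like "+cu126"
TORCH_AT_LEAST_2_6 = Version(torch.__version__.split("+")[0]) >= Version("2.6.0")

if TORCH_AT_LEAST_2_6:
    pass  # use features that require PyTorch 2.6+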
Drop support for PyTorch 2.5 and before (#2720)
torchao only supports the latest three versions of PyTorch. Please upgrade to PyTorch 2.6.0 or later if you are using an older version.
New Features
- New multi-step QAT API (#2629)
- Add float8 FakeQuantizeConfig and FakeQuantizer (#2735)
- (prototype) Add NVFP4 QAT (#2666)
Improvements
- Add StretchedUnifTorchaoQuantizer (#2576)
- Allow symmetric_no_clipping_error for KleidiAI kernels, update Readme and validate Kleidi INT4 quantization path (#2570)
- Enable powers of 2 cast in float8 rowwise_with_gw_hp recipe (#2677)
- Don't call erase if node is already erased in batch norm fusion. (#2716)
- Generalize FakeQuantizer beyond intx (#2714)
- Allow pattern replacement to ignore literals (#2519)
- Replace export_for_training with torch.export.export (#2724)
- Allow no quantization during QATConfig convert (#2694)
- Int4 sparse marlin tensor (#2771)
- Remove group_size arg in Float8DynamicActivationInt4WeightConfig (#2779)
- Fix batch norm folding in prepare_pt2e for multiple conv->BN chains sharing the same conv weights (#2795)
- Add Float8Tensor (#2463)
- (prototype) Allow per-group quantizers in QuantOptimizer, fix state_dict (#2743)
- (prototype) SpinQuant: support split qkv (#2547)
- (prototype) Make AWQ more general (#2400)
- (prototype) MX training
- (prototype) MoE training
- Mxfp8 emulated grouped gemm (#2626)
- Add differentiable mxfp8 grouped gemm with dynamic quant (forward pass) (#2627)
- Support for 2d-2d emulated mxfp8 grouped gemm (#2632)
- Backward pass for differentiable mxfp8 grouped gemm with dynamic quant (#2639)
- torch.compile support for ScaledGroupedMMTensor (#2509)
- Assert expert weights are column-major; preserve subclass with transpose (#2663)
- set token group alignment size to 16 for fp8 training test (#2678)
- Make scaling type configurable for MoE training (#2642)
- use smaller block sizes for per group scaling kernels to improve perf (#2668)
- add llama4 benchmarking script (#2669)
- add fp8 rowwise kernels for expert weights (#2696)
- add bench script for fp8 rowwise kernels and update autotune configs (#2697)
- integrate rowwise expert quant kernel (#2698)
- work around wrap_triton bug by using normal custom ops instead for fp8 rowwise kernels (#2734)
- fix scaling type bug; refactor distributed tests (#2749)
- use llama4 shapes for kernel benchmarks (#2756)
- remove duplicate benchmark script (#2762)
- refactor to share benchmarking and profiling utils (#2767)
- add memory bandwidth calculations to kernel benchmarking scripts (#2769)
- update bench script to compare fp8 dynamic quant scaled_grouped_mm fwd+bwd against bf16 (#2765)
- Float8 blockwise training (prototype)
Bug Fixes
- Fix autocast handling for float8 training rowwise recipes (#2587)
- NVFP4 -> Use more of e4m3 range for block_scales (#2604)
- Handle the case when param groups are passed to optimizer (#2606)
- Fix bc breakage flex path (#2652)
- Fix FSDP2 breakage in nightly (#2684)
- When replacing literals with placeholders lists are always converted to (#2518)
- Don't learn zero points for symmetric quantization (#2739)
- fix ROCM build for newer hipblaslt BC-breaking change (#2510)
- Fix missing QuantOptimizer methods (#2770)
- Fix float8 + int4 QAT (#2851)
- Allowlist WeightWithDynamicFloat8CastTensor for deserialization for checkpointing (#2573)
Performance
- Fix float8 rowwise inference perf with torch.compile (#2672)
- Add CUDA kernel for MXFP8 dim1 casting (#2513, #2550)
- Extend the MX cast benchmark to include casting to mxfp4 (#2693)
Documentation
- Add QLoRA and FP8 to finetuning tutorial (part 2) (#2542)
- Clean up QAT API surface + add separate API ref (#2567)
- Update float8 README with AMD MI300X benchmark results (#2736)
- Update float8 README.md with more recent e2e performance numbers (#2774, #2580)
- Update quantization overview and contributor guide doc (#2723)
- add e2e training benchmark results to mx_formats README.md (#2777)
- Update paper link readme (#2563)
- Minor improvements to OpenVINOQuantizer (#2581)
- Update README with PEFT integration + installation (#2559)
Developers
- Bump cutlass version to 4.1.0 (#2589)
- Fix git repo url in citation (#2599)
- Simplify Float8Linear (#2594, #2595)
- Convert quantization internal methods to private (#2568)
- Reference representation of dqlinear int4 for xnnpack (#2520)
- Refactors to align with new tensor subclass design
- Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig (#2474)
- Add support for float8 activation for Int4PreshuffledTensor (#2437)
- Align Int4Tensor implementation details with the design of Float8Tensor (#2687)
- Support optional_tensor_names in TorchAOBaseTensor (#2710)
- Update Int4PreshuffledTensor to align with implementation details of the Float8Tensor (#2738)
- Nvfp4 tensor: switch to using qdata (#2787)
- Nvfp4 tensor: switch to TorchAOBaseTensor (#2788)
- Nvfp4 tensor: refactor weight-only vs dynamic quant (#2790)
- Mxtensor: make data argument first and rename to qdata (#2804)
- Mxtensor: inherit from TorchAOBaseTensor (#2805)
- Mxtensor: refactor activation quant to use direct logic (#2806)
- Support more ops in TorchAOBaseTensor (#2609)
New Contributors
- @wdvr made their first contribution in #2548
- @carmocca made their first contribution in #2539
- @gausah-arm made their first contribution in #2570
- @daniil-lyakhov made their first contribution in #2581
- @zeshengzong made their first contribution in #2599
- @amdfaa made their first contribution in #2662
- @chowarfb made their first contribution in #2657
- @abeakkas made their first contribution in #2716
- @subhankarpal made their first contribution in #2795
Full Changelog: v0.12.0...v0.13.0-rc1