diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3de15369915..8f1c8487746 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -6,8 +6,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter -from vllm._custom_ops import (cutlass_scaled_fp4_mm, - cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -21,12 +20,109 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype) + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def ref_nvfp4_quant(x, global_scale, block_size): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // block_size, block_size)) + vec_max = torch.max(torch.abs(x), dim=-1, + keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + # both outputs are float32 + return cast_to_fp4(clipped_x), scale.squeeze(-1) + class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" @@ -193,7 +289,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -262,10 +358,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config - self.cutlass_nvfp4_supported = cutlass_fp4_supported() - if not self.cutlass_nvfp4_supported: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and above.") + # self.cutlass_nvfp4_supported = cutlass_fp4_supported() + # if not self.cutlass_nvfp4_supported: + # raise ValueError("Current platform does not support NVFP4" + # " quantization. Please use Blackwell and above.") def create_weights( self, @@ -386,25 +482,45 @@ def apply( output_dtype = x.dtype # for input only the contracting dimension has a constraint. - x_m, _ = x.shape - w_n, _ = layer.weight.shape + x_m, x_k = x.shape + w_n, w_k = layer.weight.shape + # print(f"{x.shape=}") + # print(f"{layer.weight.shape=}") output_shape = [x_m, w_n] + block_size = 16 + + # quantize input to (FP4 and interleaved block scale) + # x_global_scale = layer.input_scale + x_global_scale = 1 / layer.input_scale + # x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size) + # x_blockscale = self.swizzle_blockscale(x_blockscale) + # print(f"{x_fp4.shape=}") + # print(f"{x_blockscale.shape=}") + + # dequantize input + x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size) + x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale + x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) + del x_fp4, x_blockscale + + # dequantize weight + w_fp4 = layer.weight.data.view(torch.uint8) + w_blockscale = layer.weight_scale_swizzled.data + w_global_scale = layer.weight_scale_2 + # print(f"{w_fp4.shape=}") + # print(f"{w_blockscale.shape=}") + # print(f"{w_global_scale.shape=}") + w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, + output_dtype, x.device, + block_size).to(output_dtype) + # print(f"{w_dq.shape=}") + + # matmul + out = torch.matmul(x_dq, w_dq.t()) + del x_dq, w_dq + # print(f"{out.shape=}") - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - s_quant = 1 / layer.input_scale - x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) - - # validate dtypes of quantized input, input block scale, - # weight and weight_blockscale - assert (x_fp4.dtype == torch.uint8) - assert (layer.weight.dtype == torch.uint8) - assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) - assert (layer.alpha.dtype == torch.float32) - - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, - output_dtype) if bias is not None: out = out + bias return out.view(*output_shape)