diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py
index 43bc596d849..833fa284531 100644
--- a/src/sparseml/modifiers/quantization/gptq/base.py
+++ b/src/sparseml/modifiers/quantization/gptq/base.py
@@ -50,6 +50,7 @@ class GPTQModifier(Modifier):
                 - LayerCompressor.revert_layer_wrappers()
 
+    :param actorder: Whether to use activation reordering or not
    :param sequential_update: Whether or not to update weights sequentially by layer,
        True saves on GPU memory
    :param targets: list of layer names to compress during GPTQ, or '__ALL__'
diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index ded28b4123b..9f660b987fb 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -88,6 +88,7 @@ def fasterprune(
         Run pruning and quantization(if applicable) on the layer up to the target
         sparsity value.
 
+        :param actorder: Flag to apply activation reordering
         :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
@@ -127,6 +128,9 @@ def fasterprune(
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
+        actorder = False
+        invperm = None
+
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -144,6 +148,7 @@ def fasterprune(
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
+
                 q = w.clone()
                 if hasattr(self.layer, "weight_fake_quant"):
@@ -156,18 +161,42 @@ def fasterprune(
                     else:
                         q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                     q = torch.dequantize(q)
+
                 elif hasattr(self.layer, "quantization_scheme"):
                     quant_scheme = self.layer.quantization_scheme
+                    actorder = quant_scheme.weights.actorder
                     if quant_scheme.weights is not None:
+
+                        if actorder:
+                            perm = torch.argsort(torch.diag(self.H), descending=True)
+                            W = W[:, perm]
+                            self.H = self.H[perm][:, perm]
+                            invperm = torch.argsort(perm)
+
                         scale = self.layer.weight_scale
                         zero_point = self.layer.weight_zero_point
+
+                        group_size = quant_scheme.weights.group_size
+                        if group_size is None or group_size == -1:
+                            group_size = self.layer.weight.shape[1]
+
+                        if actorder:
+                            indices = torch.arange(self.columns, device=invperm.device)
+                            g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
+                            g_idx = g_idx[invperm]
+                            self.layer.weight_g_idx.data = g_idx
+                        else:
+                            indices = torch.arange(
+                                self.columns, device=W.device, dtype=torch.int32
+                            )
+                            g_idx = indices // group_size
+
                         from compressed_tensors.quantization import QuantizationStrategy
                         from compressed_tensors.quantization.lifecycle.forward import (
                             fake_quantize,
                         )
 
                         strategy = quant_scheme.weights.strategy
-
                         if strategy == QuantizationStrategy.TENSOR:
                             q = fake_quantize(
                                 q,
@@ -189,11 +218,21 @@ def fasterprune(
                             input_dim_group = (
                                 column_idx // quant_scheme.weights.group_size
                             )
-
                             # Since we're only applying quantization to a slice, this
                             # ends up being a channelwise application
                             altered_qargs = copy(quant_scheme.weights)
                             altered_qargs.strategy = QuantizationStrategy.CHANNEL
+
+                            # apply g_idx
+                            if g_idx is not None:
+                                # scale and zp already transformed by group_size
+                                # extract the first index of each group
+                                indices_to_extract = torch.arange(
+                                    0, g_idx.shape[0], group_size
+                                )
+                                scale = scale[:, g_idx[indices_to_extract]]
+                                zero_point = zero_point[:, g_idx[indices_to_extract]]
+
                             q = fake_quantize(
                                 q,
                                 scale[:, input_dim_group],
@@ -224,6 +263,9 @@ def fasterprune(
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
 
+        if actorder:
+            W = W[:, invperm]
+
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
         W = W.reshape(final_shape).to(final_dtype)
diff --git a/src/sparseml/modifiers/utils/layer_compressor.py b/src/sparseml/modifiers/utils/layer_compressor.py
index e5a36f77278..eb0b51cf269 100644
--- a/src/sparseml/modifiers/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/utils/layer_compressor.py
@@ -134,6 +134,8 @@ def revert_layer_wrappers(self):
     def compress(self):
         """
         Apply compression to each wrapped submodule in the layer
+
+        :param actorder: flag to apply activation reordering
         """
 
         @torch.no_grad()
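
Note (not part of the patch): the snippet below is a minimal, standalone sketch of the activation-ordering bookkeeping the diff introduces, written to the common GPTQ convention: columns are permuted by descending diag(H), quantized in that order, and mapped back through the inverse permutation, while g_idx records which scale/zero-point group each original column belongs to (the patch stores this mapping on the layer as weight_g_idx). All names here (columns, group_size, hessian_diag, W) are illustrative stand-ins, not the project's API.

import torch

# Illustrative sizes; real values come from the layer and its quantization scheme.
columns, group_size = 8, 4
hessian_diag = torch.rand(columns)      # stand-in for torch.diag(self.H)
W = torch.randn(16, columns)            # stand-in for the layer weight

# Activation ordering: process the columns with the largest Hessian diagonal first.
perm = torch.argsort(hessian_diag, descending=True)
invperm = torch.argsort(perm)           # undoes the permutation afterwards

W_permuted = W[:, perm]                 # columns are quantized in this order
# ... per-column GPTQ quantization would run here over W_permuted ...
W_restored = W_permuted[:, invperm]     # restore the original column order
assert torch.equal(W_restored, W)

# Group bookkeeping: groups are contiguous in the *permuted* order, so each
# original column's group id is arange(columns) // group_size, un-permuted.
g_idx = (torch.arange(columns) // group_size).to(torch.int32)[invperm]
print(g_idx)                            # per-original-column group ids

Because the groups follow the permuted order, downstream kernels need g_idx (rather than a plain column // group_size rule) to pick the right scale and zero point for each original column, which is why the patch persists it on the layer.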