This repository was archived by the owner on Jun 3, 2025. It is now read-only.

activation ordering #2316

Closed
wants to merge 15 commits
1 change: 1 addition & 0 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -50,6 +50,7 @@ class GPTQModifier(Modifier):
            - LayerCompressor.revert_layer_wrappers()

    :param actorder: Whether to use activation reordering or not
    :param sequential_update: Whether or not to update weights sequentially by layer,
        True saves on GPU memory
    :param targets: list of layer names to compress during GPTQ, or '__ALL__'
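For reference, a minimal usage sketch of the new flag (a hypothetical configuration, assuming GPTQModifier is importable from the module shown in this diff and that its remaining constructor arguments can be omitted):

from sparseml.modifiers.quantization.gptq.base import GPTQModifier

# only the fields documented above are shown; actorder enables the
# activation reordering added by this PR
modifier = GPTQModifier(
    targets="__ALL__",
    sequential_update=True,
    actorder=True,
)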
46 changes: 44 additions & 2 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -88,6 +88,7 @@ def fasterprune(
        Run pruning and quantization (if applicable) on the layer up to the target
        sparsity value.

        :param actorder: Flag to apply activation reordering
        :param blocksize: Number of columns to compress in one pass
        :param percdamp: Amount of dampening to apply to H, as a fraction of the
            diagonal norm
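For context, percdamp feeds the standard GPTQ dampening applied to H before its Cholesky factorization; a minimal sketch of that step (illustrative, not code from this diff):

import torch

H = torch.eye(4)  # stand-in Hessian
percdamp = 0.01
damp = percdamp * torch.mean(torch.diag(H))
H += damp * torch.eye(H.shape[0])  # regularize the diagonal before Cholesky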
@@ -127,6 +128,9 @@ def fasterprune(
        self.H = torch.linalg.cholesky(self.H, upper=True)
        Hinv = self.H

        # reordering is off by default; enabled below if the layer's
        # quantization scheme requests it
        actorder = False
        invperm = None

        # See section 3.4 of https://arxiv.org/abs/2203.07259
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
@@ -144,6 +148,7 @@
            for i in range(count):
                w = W1[:, i]  # current weight column
                d = Hinv1[i, i]  # its inverse-Hessian diagonal entry

                q = w.clone()

                if hasattr(self.layer, "weight_fake_quant"):
@@ -156,18 +161,42 @@
                    else:
                        q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                    q = torch.dequantize(q)

                elif hasattr(self.layer, "quantization_scheme"):
                    quant_scheme = self.layer.quantization_scheme
                    if quant_scheme.weights is not None:
                        # read the flag only after confirming weight args exist
                        actorder = quant_scheme.weights.actorder

                        if actorder:
                            # sort columns by decreasing Hessian diagonal so the
                            # most sensitive weights are quantized first
                            perm = torch.argsort(torch.diag(self.H), descending=True)
                            W = W[:, perm]
                            self.H = self.H[perm][:, perm]
                            # inverse permutation, used to restore column order
                            invperm = torch.argsort(perm)
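                            # illustration (not from this diff): for diag(H) = [1., 3., 2.],
                            # perm = [1, 2, 0] and invperm = argsort(perm) = [2, 0, 1],
                            # so W[:, perm][:, invperm] recovers the original column order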

                        scale = self.layer.weight_scale
                        zero_point = self.layer.weight_zero_point

                        group_size = quant_scheme.weights.group_size
                        if group_size is None or group_size == -1:
                            # fall back to a single group spanning every column
                            group_size = self.layer.weight.shape[1]

                        if actorder:
                            # group ids are assigned in the *reordered* layout;
                            # invperm maps them back to original column positions
                            indices = torch.arange(self.columns, device=invperm.device)
                            g_idx = (indices // group_size).to(dtype=torch.int32)
                            g_idx = g_idx[invperm]
                            self.layer.weight_g_idx.data = g_idx
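                            # illustration (not from this diff): with columns=4,
                            # group_size=2 and perm=[2, 3, 0, 1], invperm=[2, 3, 0, 1]
                            # and g_idx=[1, 1, 0, 0]: original columns 0 and 1 land
                            # in the second group after reordering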
                        else:
                            indices = torch.arange(
                                self.columns, device=W.device, dtype=torch.int32
                            )
                            g_idx = indices // group_size

                        from compressed_tensors.quantization import QuantizationStrategy
                        from compressed_tensors.quantization.lifecycle.forward import (
                            fake_quantize,
                        )

                        strategy = quant_scheme.weights.strategy

                        if strategy == QuantizationStrategy.TENSOR:
                            q = fake_quantize(
                                q,
@@ -189,11 +218,21 @@
                            input_dim_group = (
                                column_idx // quant_scheme.weights.group_size
                            )

                            # Since we're only applying quantization to a slice, this
                            # ends up being a channelwise application
                            altered_qargs = copy(quant_scheme.weights)
                            altered_qargs.strategy = QuantizationStrategy.CHANNEL

                            # apply g_idx
                            if g_idx is not None:
                                # scale and zero_point are already laid out per group;
                                # take the group id of the first column in each group
                                indices_to_extract = torch.arange(
                                    0, g_idx.shape[0], group_size
                                )
                                scale = scale[:, g_idx[indices_to_extract]]
                                zero_point = zero_point[:, g_idx[indices_to_extract]]
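                                # illustration (not from this diff): continuing the
                                # example above (g_idx=[1, 1, 0, 0], group_size=2),
                                # indices_to_extract=[0, 2] and g_idx[[0, 2]]=[1, 0],
                                # so the two per-group scale columns are swapped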

                            q = fake_quantize(
                                q,
                                scale[:, input_dim_group],
@@ -224,6 +263,9 @@
_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())

        if actorder:
            # undo the column permutation before writing weights back
            W = W[:, invperm]

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.reshape(final_shape).to(final_dtype)
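Taken together, the reordering logic follows a permute -> quantize-by-group -> restore pattern. A self-contained sketch of that round trip (illustrative tensors, with a naive symmetric int8 quantizer standing in for fake_quantize):

import torch

torch.manual_seed(0)
W = torch.randn(8, 16)   # [out_features, in_features]
H_diag = torch.rand(16)  # stand-in for diag(H)
group_size = 4

# reorder columns by decreasing Hessian diagonal
perm = torch.argsort(H_diag, descending=True)
invperm = torch.argsort(perm)
W_perm = W[:, perm]

# quantize each contiguous group of reordered columns
Q = torch.empty_like(W_perm)
for g in range(0, W_perm.shape[1], group_size):
    block = W_perm[:, g : g + group_size]
    scale = block.abs().amax() / 127.0
    Q[:, g : g + group_size] = torch.clamp((block / scale).round(), -128, 127) * scale

# restore the original column order before writing the weight back
Q = Q[:, invperm]
assert Q.shape == W.shape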
2 changes: 2 additions & 0 deletions src/sparseml/modifiers/utils/layer_compressor.py
@@ -134,6 +134,8 @@ def revert_layer_wrappers(self):
    def compress(self):
        """
        Apply compression to each wrapped submodule in the layer

        :param actorder: Flag to apply activation reordering
        """

    @torch.no_grad()