diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 33c6f5588..8cb53a8e6 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -18,7 +18,6 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" # Placeholder for all non-transformer models registered in QEfficient - # custom warning for the better logging experience warnings.formatwarning = custom_format_warning diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index a20fc4cb3..e503a057f 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SplitGateUpWeightsTransform(PytorchTransform): """ - split fused Gate+Up weights and copy into the model + Split fused Gate+Up weights and copy into the model. + Handles both standard MoE models and GptOss models. For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] + + Handles both interleaved weights (GptOss) and concatenated weights (standard MoE). + Also handles bias terms when present. """ @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: return model, transformed model_tmp = model.language_model if hasattr(model, "language_model") else model - num_layers = len(model_tmp.model.layers) delete_fused_key = True sd = model_tmp.state_dict() + for layer_idx in range(num_layers): + # Determine if this is a GptOss model or standard MoE model + is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp") + # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." + if is_gpt_oss: + prefix = f"model.layers.{layer_idx}.mlp.experts." + experts = model_tmp.model.layers[layer_idx].mlp.experts + else: + prefix = f"model.layers.{layer_idx}.feed_forward.experts." + experts = model_tmp.model.layers[layer_idx].feed_forward.experts fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" up_key = prefix + "up_proj" - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- - fused = sd[fused_key] # [E, H, 2I] (no .weight here) + # Check if we have bias terms (GptOss case) + has_bias = fused_key + "_bias" in sd + if has_bias: + fused_bias_key = fused_key + "_bias" + gate_bias_key = gate_key + "_bias" + up_bias_key = up_key + "_bias" + + # ---- split weights based on model type ---------------------- + fused = sd[fused_key] # [E, H, 2I] E, H, two_I = fused.shape - ffn_dim = two_I // 2 - gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].feed_forward.experts + if is_gpt_oss: + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] + gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + else: + # For standard MoE, gate/up are concatenated: [gate, up] + ffn_dim = two_I // 2 + gate, up = fused.split(ffn_dim, dim=-1) # views – no copy + + # Copy weights to model experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) + # Handle bias if present + if has_bias: + fused_bias = sd[fused_bias_key] # [E, 2I] + + if is_gpt_oss: + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + else: + ffn_dim = fused_bias.shape[-1] // 2 + gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1) + + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) + # ---- update the state-dict so load_state_dict sees the right keys sd[gate_key] = gate sd[up_key] = up + if has_bias: + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + # Delete fused keys if delete_fused_key: del sd[fused_key] + if has_bias: + del sd[fused_bias_key] - logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") + logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") transformed = True if hasattr(model, "language_model"): model.language_model = model_tmp else: model = model_tmp + return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} +# Keep the existing list of supported models +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..72a055dde 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -537,3 +537,102 @@ def update( ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.sliding_window_len = sliding_window_len + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls( + config, + batch_size=past_key_values[0][0].shape[0], + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Original Gather + ctx_len = self.key_cache[layer_idx].shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index c692d1beb..5337b44f5 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -185,6 +185,7 @@ ] ) +# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # Define a transformers layers to QEff layers dictionary diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..bc460fea6 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,711 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffGptOssMLP(GptOssMLP): + def alt_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + # ------------------- Gather based, weights as activation approach --------------- + def forward_weights_as_activation(self, hidden_states): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts + gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()] + gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Apply Chosen Experts (without routing weights first) + # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0) + # expert_in = expert_in.view(-1, 1, self.experts.hidden_size) + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation (This is before on Llama4) + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- + def forward(self, hidden_states): + # print("Seperate Split, Up, Gate Projections") + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts (separate gate and up projections) + gate_proj = self.experts.gate_proj[router_indices.flatten()] + gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()] + up_proj = self.experts.up_proj[router_indices.flatten()] + up_proj_bias = self.experts.up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + # Apply gate and up projections separately using bmm + gate = torch.bmm(expert_in, gate_proj) + gate_proj_bias.unsqueeze(1) + up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + # Down projection + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + def optimized_moe_forward(self, hidden_states: torch.Tensor): + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + + # Top-k selection + top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + # Creating experts mask and routing weights masked + awesome_experts_mask_1 = ( + torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_2 = ( + torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_3 = ( + torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_4 = ( + torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + + gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + # kv_seq_len = key_states.shape[-2] + + # kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + # alth, _ = self.mlp.alt_forward(hidden_states) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_pkv_dynamic_axes( + self, + ): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + elif layer_type == "full_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2a00577f2..6ac72c630 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,6 +51,7 @@ AwqToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, ) from QEfficient.utils import ( constants, @@ -2027,6 +2028,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, @@ -2283,10 +2285,20 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names.append(f"past_{kv}.{i}_RetainedState") else: + # HACK: create common function for this including above if condition code + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + for i in range(self.num_layers): for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: @@ -2636,6 +2648,11 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + # HACK for now + if self.model.config.model_type == "gpt_oss": + for spec in specializations: + spec.update({"sliding_window": 128}) + qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index c910ab387..1a1a1d275 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -51,6 +51,15 @@ GPTBigCodeForCausalLM, GPTBigCodeModel, ) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRMSNorm, +) from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -231,6 +240,14 @@ QEffGPTBigCodeForCausalLM, QEffGPTBigCodeModel, ) +from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffGptOssDecoderLayer, + QEffGptOssExperts, + QEffGptOssForCausalLM, + QEffGptOssMLP, + QEffGptOssModel, +) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, QEffGPTJBlock, @@ -408,6 +425,7 @@ class CustomOpsTransform(ModuleMappingTransform): GraniteRMSNorm: CustomRMSNormAIC, PixtralRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, + GptOssRMSNorm: CustomRMSNormAIC, Qwen3MoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, Olmo2RMSNorm: CustomRMSNormAIC, @@ -480,6 +498,13 @@ class KVCacheTransform(ModuleMappingTransform): Gemma3TextModel: QEffGemma3TextModel, Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # GPT_OSS + GptOssAttention: QEffGptOssAttention, + GptOssDecoderLayer: QEffGptOssDecoderLayer, + GptOssModel: QEffGptOssModel, + GptOssForCausalLM: QEffGptOssForCausalLM, + GptOssMLP: QEffGptOssMLP, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index d647b73a6..dfadc00ef 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -4,3 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers + +__all__ = ["replace_transformers_quantizers"] diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py index ba204e419..d73909211 100644 --- a/QEfficient/transformers/quantizers/auto.py +++ b/QEfficient/transformers/quantizers/auto.py @@ -11,7 +11,8 @@ from transformers.quantizers.quantizer_awq import AwqQuantizer from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.quantizers.quantizer_gptq import GptqHfQuantizer -from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( @@ -21,30 +22,35 @@ QEffFP8Quantizer, ) from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer QEFF_AUTO_QUANTIZER_MAPPING = { "awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer, "compressed-tensors": QEffCompressedTensorsFP8Quantizer, "fp8": QEffFP8Quantizer, + "mxfp4": QEffMxfp4HfQuantizer, } QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": QEffAwqConfig, "gptq": QEffGPTQConfig, "compressed-tensors": QEffCompressedTensorsConfig, "fp8": QEffFP8Config, + "mxfp4": QEffMxfp4Config, } DUPLICATE_AUTO_QUANTIZER_MAPPING = { "awq": AwqQuantizer, "gptq": GptqHfQuantizer, "compressed-tensors": CompressedTensorsHfQuantizer, "fp8": None, + "mxfp4": Mxfp4HfQuantizer, } DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": AwqConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "fp8": None, + "mxfp4": Mxfp4Config, } diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 0427bca37..69d6380f0 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,13 +7,19 @@ import torch from torch import nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear -from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts +from QEfficient.transformers.quantizers.quantizer_utils import ( + convert_moe_packed_tensors, + dequantize_gptq, + unpack_weights, +) class AwqToMatmulNbitsTransform(ModuleMutatorTransform): @@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module): if original_module.bias is not None: dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) return dequant_linear_layer + + +class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an Mxfp4GptOssExpert module and replace with transformers GptOssExperts with dequantized weights + """ + + _match_class = QEffMxfp4GptOssExperts + + @classmethod + def mutate(cls, original_module, parent_module): + dequant_module = GptOssExperts(original_module.config) + dequant_module.gate_up_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias + dequant_module.down_proj_bias = original_module.down_proj_bias + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py new file mode 100644 index 000000000..2ffba1bea --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_mxfp4.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import re +from typing import Optional + +import torch +import torch.nn as nn +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import Mxfp4Config + +from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert +from QEfficient.utils.logging_utils import logger + + +class QEffMxfp4GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + + self.gate_up_proj_blocks = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + ) + + self.down_proj_blocks = nn.Parameter( + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + ) + self.alpha = 1.702 + self.limit = 7.0 + + self.gate_up_proj_precision_config = None + self.down_proj_precision_config = None + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + gate_up_proj = convert_moe_packed_tensors( + self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32 + ) + down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32) + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False + + +class QEffMxfp4Config(Mxfp4Config): + """ + Currently there is not need to change the implementation of Mxfp4Config + This is placeholder for future when we would want to change this + """ + + pass + + +class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer): + def validate_environment(self, *args, **kwargs): + return True + + def update_torch_dtype(self, torch_dtype): + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading( + self, + model: torch.nn.Module, + keep_in_fp32_modules: Optional[list[str]] = None, + **kwargs, + ): + self.modules_to_not_convert = get_keys_to_not_convert(model) + self.modules_to_not_convert = ( + ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert + ) + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + config = model.config + + # -- Defining local method as it uses lot of local variables -- + def _replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + ): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + if not should_convert_module(current_key_name, modules_to_not_convert): + current_key_name.pop(-1) + continue + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: + model._modules[name] = QEffMxfp4GptOssExperts(config) + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + _replace_with_mxfp4_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index a318fb8e4..881357c54 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -378,3 +378,71 @@ def repack_zeros(qzeros, bits): break qzeros = qzeros.T return qzeros + + +FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + """ + reference for this function is taken from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98 + """ + import math + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + out = out.to(dtype).permute(0, 2, 1).contiguous() + return out diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index eb1f7c8e6..94bc6ac1b 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -91,9 +91,13 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(1).view(-1, 1) past_key_values = [] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): - past_key = torch.zeros((self.padding_shape), dtype=torch.float32) - past_value = torch.zeros((self.padding_shape), dtype=torch.float32) + pad_shape = ( + sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape + ) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py new file mode 100644 index 000000000..fd00f88fd --- /dev/null +++ b/examples/gpt_oss.py @@ -0,0 +1,35 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. + ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=8, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + prompts="Who is your creator? and What all you are allowed to do?", + device_id=[0, 1, 2, 3], +) diff --git a/pyproject.toml b/pyproject.toml index ea3c3405d..dbff208a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ dependencies = [ "transformers==4.55.0", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft==0.13.2", - "datasets==2.20.0", + "peft", + "datasets", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index bdc15519e..9c7b75c1b 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -33,6 +33,7 @@ ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gpt_oss", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [