Adding CETT-based thresholding #53

Closed · wants to merge 5 commits

generate_dataset.py (10 changes: 5 additions & 5 deletions)
@@ -43,7 +43,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import set_seed

from src.activation_capture import ActivationCaptureTraining
from src.activation_capture import Hook

# Setup logging
logging.basicConfig(level=logging.INFO)
@@ -120,14 +120,14 @@ def process_batch(
hidden_states_dict = {}
mlp_activations_dict = {}
for layer_idx in range(num_layers):
hidden_state = model.activation_capture.get_hidden_states(layer_idx)[0]
hidden_state = model.activation_capture.mlp_activations[Hook.IN][layer_idx][0]
hidden_states_dict[layer_idx] = (
hidden_state.view(-1, hidden_state.shape[-1])
.cpu()
.numpy()
.astype(np.float32)
)
mlp_activation = model.activation_capture.get_gate_activations(layer_idx)
mlp_activation = model.activation_capture.mlp_activations[Hook.ACT][layer_idx]
mlp_activations_dict[layer_idx] = (
mlp_activation[0]
.view(-1, mlp_activation.shape[-1])
@@ -172,8 +172,8 @@ def generate_dataset(
model = model.to(device)

model.eval()
model.activation_capture = ActivationCaptureTraining(model)
model.activation_capture.register_hooks()
model.activation_capture = model.ACTIVATION_CAPTURE(model)
model.activation_capture.register_hooks(hooks=[Hook.IN, Hook.ACT])

# Get model dimensions
hidden_dim = model.config.hidden_size
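
For context, a minimal sketch of how the refactored capture API in this diff would be used during dataset generation. It assumes a model and tokenizer already loaded as in generate_dataset.py, with a model class that exposes the ACTIVATION_CAPTURE helper introduced by this PR; the prompt text and layer index are illustrative, not part of the patch.

import torch
from src.activation_capture import Hook

# Assumes `model` and `tokenizer` are loaded as in generate_dataset.py, with a
# model class that exposes the ACTIVATION_CAPTURE helper added in this PR.
model.eval()
model.activation_capture = model.ACTIVATION_CAPTURE(model)
model.activation_capture.register_hooks(hooks=[Hook.IN, Hook.ACT])

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

# Captured tensors are keyed by hook type, then by layer index.
mlp_input_l0 = model.activation_capture.mlp_activations[Hook.IN][0]   # hidden states entering layer 0's MLP
act_input_l0 = model.activation_capture.mlp_activations[Hook.ACT][0]  # input to layer 0's act_fn

model.activation_capture.remove_hooks()
model.activation_capture.clear_captures()
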
measure_contextual_sparsity.py (23 changes: 7 additions & 16 deletions)
@@ -13,7 +13,7 @@
from transformers.trainer_utils import set_seed

import matplotlib.pyplot as plt
from src.activation_capture import ActivationCaptureDefault
from src.activation_capture import Hook

# Setup logging
logging.basicConfig(level=logging.INFO)
@@ -28,16 +28,14 @@ def __init__(self, model, tokenizer, device):
self.tokenizer = tokenizer
self.device = device

model.activation_capture = ActivationCaptureDefault(model)
model.activation_capture.register_hooks()
model.activation_capture = model.ACTIVATION_CAPTURE(model)
model.activation_capture.register_hooks(hooks=[Hook.ACT])
self.num_layers = len(self.model.activation_capture.get_layers())

self.reset_buffers()

def reset_buffers(self):
self.mlp_sparsity = {}
self.mlp_sparsity["gate"] = defaultdict(list)
self.mlp_sparsity["up"] = defaultdict(list)
self.mlp_sparsity = defaultdict(list)
self.num_seqs = 0

def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
@@ -54,26 +52,19 @@ def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):

# Compute sparsity
for layer_idx in range(self.num_layers):
sparsity_masks_gate = (
self.model.activation_capture.get_gate_activations(layer_idx) <= 0
)
sparsity_masks_up = (
self.model.activation_capture.get_up_activations(layer_idx) <= 0
sparsity_masks = (
self.model.activation_capture.mlp_activations[Hook.ACT][layer_idx] <= 0
)

# Naive sparsity computation
self.mlp_sparsity["gate"][layer_idx].append(
sparsity_masks_gate.float().mean().item()
)
self.mlp_sparsity["up"][layer_idx].append(
sparsity_masks_up.float().mean().item()
sparsity_masks.float().mean().item()
)

# Level of sparsity after union over batch dim
# union_sparsity_mask = sparsity_masks.any(dim=0)
# self.union_sparsity[batch_size][layer_idx].append(union_sparsity_mask.float().mean().item())

# TODO: Add HNSW sparsity computation for both attn heads and mlp neurons
# TODO: Compute union sparsity over multiple different batch sizes

# Clear GPU tensors from capture to free memory
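
To make the simplified sparsity bookkeeping above concrete, a small self-contained sketch follows; the tensor is a random stand-in for mlp_activations[Hook.ACT][layer_idx] and the shape is illustrative.

import torch
from collections import defaultdict

mlp_sparsity = defaultdict(list)

# Random stand-in for the captured Hook.ACT tensor of one layer:
# (batch, seq_len, intermediate_dim)
acts = torch.randn(4, 32, 11008)

# A neuron counts as inactive when its pre-activation is <= 0. For SiLU-gated
# MLPs this matches thresholding the activation output at 0, since SiLU keeps
# the sign of its input.
sparsity_masks = acts <= 0
mlp_sparsity[0].append(sparsity_masks.float().mean().item())  # layer 0, naive per-token sparsity

print(f"layer 0 naive sparsity: {mlp_sparsity[0][-1]:.3f}")

With standard-normal inputs this prints roughly 0.5; the measured quantity only becomes meaningful on real captured pre-activations.
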
src/activation_capture.py (181 changes: 65 additions & 116 deletions)
@@ -1,54 +1,91 @@
from typing_extensions import override
import torch.nn.functional as F
from abc import ABC, abstractmethod

from enum import Enum
from typing import List

class ActivationCapture(ABC):
class Hook(Enum):
IN = "IN"
ACT = "ACT"
UP = "UP"
OUT = "OUT"


class ActivationCapture():
"""Helper class to capture activations from model layers."""
has_gate_proj: bool
has_up_proj: bool
hooks_available: List[Hook]

def __init__(self, model):
self.model = model
self.mlp_activations = {}
self.mlp_activations = {
hook: {} for hook in self.hooks_available
}
self.handles = []

@abstractmethod
def _register_gate_hook(self, layer_idx, layer):
pass
def _register_in_hook(self, layer_idx, layer):
def hook(module, input, output):
# Clone and detach; don't move to CPU yet
self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach()
return output
handle = layer.mlp.register_forward_hook(hook)
return handle

def _register_act_hook(self, layer_idx, layer):
def hook(module, input, output):
# Clone and detach; don't move to CPU yet
self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach()
return output
handle = layer.mlp.act_fn.register_forward_hook(hook)
return handle

@abstractmethod
def _register_up_hook(self, layer_idx, layer):
pass
def hook(module, input, output):
# Clone and detach; don't move to CPU yet
self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach()
return output
handle = layer.mlp.down_proj.register_forward_hook(hook)
return handle

def _register_out_hook(self, layer_idx, layer):
def hook(module, input, output):
# Clone and detach; don't move to CPU yet
self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach()
return output
handle = layer.mlp.register_forward_hook(hook)
return handle

@abstractmethod
def get_layers(self):
pass


@abstractmethod
def get_gate_activations(self, layer_idx):
"""Get combined MLP activations for a layer."""
pass
return self.model.get_decoder().layers

def register_hooks(self):
def register_hooks(self, hooks=(Hook.ACT, Hook.UP, Hook.OUT)):
"""Register forward hooks to capture activations."""
# Clear any existing hooks
self.remove_hooks()

# Hook into each transformer layer
for i, layer in enumerate(self.get_layers()):
# Capture MLP gate activations (after activation function)
if self.has_gate_proj:
handle = self._register_gate_hook(i, layer)
for i, layer in enumerate(self.get_layers()):
# Hooks capturing inputs to the MLP layer
if Hook.IN in hooks and Hook.IN in self.hooks_available:
handle = self._register_in_hook(i, layer)
if handle is not None:
self.handles.append(handle)

# Also capture up_proj activations
if self.has_up_proj:

# Hooks capturing inputs to the activation function
if Hook.ACT in hooks and Hook.ACT in self.hooks_available:
handle = self._register_act_hook(i, layer)
if handle is not None:
self.handles.append(handle)

# Hooks capturing inputs to the down projection
if Hook.UP in hooks and Hook.UP in self.hooks_available:
handle = self._register_up_hook(i, layer)
if handle is not None:
self.handles.append(handle)

# Hooks capturing the final MLP output
if Hook.OUT in hooks and Hook.OUT in self.hooks_available:
handle = self._register_out_hook(i, layer)
if handle is not None:
self.handles.append(handle)


def remove_hooks(self):
"""Remove all registered hooks."""
@@ -59,91 +96,3 @@ def remove_hooks(self):
def clear_captures(self):
"""Clear captured activations."""
self.mlp_activations = {}



class ActivationCaptureDefault(ActivationCapture):
"""Helper class to capture activations from model layers."""
has_gate_proj: bool = True
has_up_proj: bool = True

def get_layers(self):
return self.model.get_decoder().layers

def _create_mlp_hook(self, layer_idx, proj_type):
def hook(module, input, output):
key = f"{layer_idx}_{proj_type}"
# Just detach, don't clone or move to CPU yet
self.mlp_activations[key] = output.clone().detach()
return output
return hook

def _register_gate_hook(self, layer_idx, layer):
handle = layer.mlp.gate_proj.register_forward_hook(
self._create_mlp_hook(layer_idx, 'gate')
)
return handle

def _register_up_hook(self, layer_idx, layer):
handle = layer.mlp.up_proj.register_forward_hook(
self._create_mlp_hook(layer_idx, 'up')
)
return handle

def get_gate_activations(self, layer_idx):
gate_key = f"{layer_idx}_gate"
if gate_key in self.mlp_activations:
gate_act = self.mlp_activations[gate_key]
return F.silu(gate_act)
return None

def get_up_activations(self, layer_idx):
up_key = f"{layer_idx}_up"
if up_key in self.mlp_activations:
up_act = self.mlp_activations[up_key]
return up_act
return None

class ActivationCaptureTraining(ActivationCaptureDefault):
"""Additional Hidden State capture for training dataset generation"""
def __init__(self, model):
super().__init__(model)
self.hidden_states = {}

def _create_hidden_state_hook(self, layer_idx, layer):
def hook(module, args, kwargs, output):
# args[0] is the input hidden states to the layer
if len(args) > 0:
# Just detach, don't clone or move to CPU yet
self.hidden_states[layer_idx] = args[0].clone().detach()
return output
return hook

def _register_hidden_state_hook(self, layer_idx, layer):
handle = layer.register_forward_hook(
self._create_hidden_state_hook(layer_idx, layer),
with_kwargs=True
)
return handle

@override
def clear_captures(self):
"""Clear captured activations."""
super().clear_captures()
self.hidden_states = {}

@override
def register_hooks(self):
"""Register forward hooks to capture activations."""
# Clear any existing hooks
super().register_hooks()
# Hook into each transformer layer
for i, layer in enumerate(self.get_layers()):
# Capture hidden states before MLP
handle = self._register_hidden_state_hook(i, layer)
if handle is not None:
self.handles.append(handle)

def get_hidden_states(self, layer_idx):
"""Get hidden states for a layer."""
return self.hidden_states[layer_idx]
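
For readers less familiar with the forward-hook mechanism the refactored ActivationCapture relies on, a self-contained toy sketch follows; ToyMLP and its dimensions are illustrative stand-ins, not code from this PR. It mirrors the hook placement above: a hook on the MLP module captures its input (Hook.IN) and a hook on act_fn captures the pre-activation (Hook.ACT).

import torch
import torch.nn as nn
from enum import Enum

class Hook(Enum):
    IN = "IN"
    ACT = "ACT"

class ToyMLP(nn.Module):
    """Stand-in for a decoder layer's gated MLP (gate/up/down projections + SiLU)."""
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = ToyMLP(hidden_dim=16, intermediate_dim=64)
captured = {Hook.IN: {}, Hook.ACT: {}}
layer_idx = 0

def make_hook(kind, idx):
    # Forward hooks receive (module, inputs, output); we stash the first input.
    def hook(module, inputs, output):
        captured[kind][idx] = inputs[0].clone().detach()
        return output
    return hook

handles = [
    mlp.register_forward_hook(make_hook(Hook.IN, layer_idx)),          # MLP input
    mlp.act_fn.register_forward_hook(make_hook(Hook.ACT, layer_idx)),  # act_fn input (gate pre-activation)
]

x = torch.randn(2, 5, 16)
mlp(x)
print(captured[Hook.IN][layer_idx].shape)   # torch.Size([2, 5, 16])
print(captured[Hook.ACT][layer_idx].shape)  # torch.Size([2, 5, 64])

for h in handles:
    h.remove()
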