diff --git a/mixtral-moe/README.md b/mixtral-moe/README.md
index cf5e9d9b..bfe1c2b1 100644
--- a/mixtral-moe/README.md
+++ b/mixtral-moe/README.md
@@ -4,7 +4,7 @@
 ## Downloading Weights
 
 ```bash
-export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
+export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
 python scripts/download.py --repo_id $MODEL_REPO
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ```
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 9aa076b6..43e7c585 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -25,7 +25,7 @@ def device_sync(device):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-
+torch._dynamo.config.capture_scalar_outputs = True
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -52,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs
 
 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
@@ -74,11 +74,13 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
+            next_token, next_prob = next_token.clone(), next_prob.clone()
+            input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token
 
     return new_tokens, new_probs
 
@@ -91,6 +93,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     callback = lambda x: x,
@@ -99,32 +102,30 @@
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
""" + device, dtype = prompt.device, prompt.dtype + + + T = prompt.size(-1) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) # create an empty tensor of the expected final shape and fill in the current tokens - T = prompt.size(0) - T_new = T + max_new_tokens - if interactive: - max_seq_length = 350 - else: - max_seq_length = min(T_new, model.config.block_size) + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt - device, dtype = prompt.device, prompt.dtype with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty input_pos = torch.arange(0, T, device=device) - - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token + next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) + seq[:, T] = next_token.squeeze() input_pos = torch.tensor([T], device=device, dtype=torch.int) - - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) + generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) return seq @@ -144,8 +145,12 @@ def _load_model(checkpoint_path, device, precision, use_tp): simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8) model = simple_quantizer.convert_for_runtime() - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - model.load_state_dict(checkpoint, assign=True) + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + if use_tp: from tp import apply_tp @@ -162,6 +167,7 @@ def main( interactive: bool = False, num_samples: int = 5, max_new_tokens: int = 100, + batch_size: int = 1, top_k: int = 200, temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), @@ -172,8 +178,7 @@ def main( ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. 
""" - assert checkpoint_path.is_file(), checkpoint_path - + # assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) @@ -202,13 +207,81 @@ def main( torch.manual_seed(1234) model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + + + import torchao + from torchao.quantization import quantize_, Int8WeightOnlyConfig + + + def filter(model, fqn): + return isinstance(model, torch.nn.Linear) and "gate" not in fqn + + quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter) + + + from torchao.quantization.quant_primitives import MappingType + from torchao.dtypes import to_affine_quantized_intx + + def moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + + def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) + + def quant_convert_fn(module, config): + def quant_tensor(weight): + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = [1 for x in range(param.dim())] + block_size[-1] = param.shape[-1] + block_size = tuple(block_size) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + return new_weight + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + assert hasattr(module, "w1") + assert hasattr(module, "w2") + assert hasattr(module, "w3") + + group_size = None if config.group_size is None else config.group_size + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + new_param = quant_tensor(param) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + # _replace_with_custom_fn_if_matches_filter( + # model, + # quant_convert_fn, + # cond_ffn_filter, + # extra_args=(Int8WeightOnlyConfig(),) + # ) + + + if compile: torch._inductor.config.assert_indirect_indexing = False + # torch._dynamo.config.capture_dynamic_output_shape_ops = True global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + if batch_size > 1: # MoE code has graph break for multi token path so can't fullgraph compile + # decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) - # Uncomment to squeeze more perf out of prefill if args.compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) @@ -255,6 +328,7 @@ def callback(x): model, encoded, max_new_tokens, + batch_size, interactive=interactive, callback=callback, temperature=temperature, @@ -272,16 +346,19 @@ def callback(x): t = time.perf_counter() - t0 if not interactive: - print(tokenizer.decode(y.tolist())) + print(tokenizer.decode(y[0].tolist())) else: print() - tokens_generated = y.size(0) - prompt_length + tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t aggregate_metrics['tokens_per_sec'].append(tokens_sec) print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") print(f"Bandwidth 
-    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+    tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
+    print(f"Average tokens/sec: {tokpersec:.2f}")
+    if batch_size > 1:
+        print(f"Average tokens/sec including batches: {batch_size*tokpersec:.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
@@ -291,8 +368,10 @@ def callback(x):
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
-    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
+    parser.add_argument('--num_samples', type=int, default=2, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with.')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
@@ -303,6 +382,6 @@
     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
     )
diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py
index 9249ac9d..e930f37c 100644
--- a/mixtral-moe/model.py
+++ b/mixtral-moe/model.py
@@ -52,7 +52,7 @@ def from_name(cls, name: str):
 
 transformer_configs = {
-    "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
+    "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
 }
 
 class KVCache(nn.Module):
@@ -122,13 +124,14 @@ class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
-        self.block_sparse_moe = MOEFeedForward(config)
+        self.block_sparse_moe = MOEFeedForwardAOQuantizable(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
-        out = h + self.block_sparse_moe(self.ffn_norm(h))
+        moe_out = self.block_sparse_moe(self.ffn_norm(h))
+        out = h + moe_out
         return out
 
 
@@ -258,3 +262,113 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)
+
+
+# T tokens
+# E experts
+# D dim
+# I intermediate dim
+# A activated experts
+# T'(e) tokens for expert e
+
+class MOEFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
+        self.cond_ffn = ConditionalFeedForwardAOQuantizable(config)
+        self.dim = config.dim
+        self.num_activated_experts = config.num_activated_experts
+
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size = x.shape[0]
+        x = x.view(-1, self.dim)  # x: [T, D]
+        scores = self.gate(x)  # [T, E]
+        expert_weights = F.softmax(scores, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
+        expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype)  # [T, A]
+        out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts)
+        return out.reshape(batch_size, -1, self.dim)
+
+
+class ConditionalFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))  # E, D, I
+        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
+        self.num_experts = config.num_experts
+
+    def forward(
+        self, x: Tensor,  # T, D
+        expert_indices: Tensor,  # T, A
+        expert_weights: Tensor,  # T, A
+        num_activated_experts: int,
+    ) -> Tensor:
+        num_tokens, dim = x.shape
+        num_token_activations = num_tokens * num_activated_experts
+
+        if x.shape[0] == 1:  # only 1 token (can be done without graph breaks when compiled)
+            outs = []
+            expert_indices = expert_indices.squeeze()
+            # collect used experts
+            w1 = self.w1[expert_indices]
+            w2 = self.w2[expert_indices]
+            w3 = self.w3[expert_indices]
+
+            # run token through each expert
+            for index in range(num_activated_experts):
+                cur_out = F.linear(F.silu(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index])
+                outs.append(cur_out)
+
+            # combine outputs
+            final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(-1)
+            return final_out
+        else:
+            expert_list = list(range(self.num_experts))
+
+            # shuffle tokens into groups for each expert
+            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)  # [T*A]
+            ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)  # [T*A]
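+            # e.g. for T=2 tokens and A=2: expert_indices.view(-1) lists activations as
+            # [t0a0, t0a1, t1a0, t1a1]; the stable argsort above orders those positions by
+            # expert id, and integer-dividing each position by A recovers its source token.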
+
+            num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts)  # [E+1] (added leading 0 so can be used for indexing)
+            cum_tokens_per_expert = num_tokens_per_expert.cumsum(0)  # [E+1]
+
+            # pulled into a helper so @torch._dynamo.disable() can be applied; compile
+            # doesn't handle the data-dependent slicing below
+            @torch._dynamo.disable()
+            def group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list):
+                token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list]  # [T'(e1)], [T'(e2)], ...
+                return token_indices_per_expert
+
+            token_indices_per_expert = group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list)
+            tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert]
+
+            # calculate outputs for each expert
+            outs = []
+            for cur_x, expert in zip(tokens_grouped_by_expert, expert_list):
+                w1 = self.w1[expert]  # I, D
+                w2 = self.w2[expert]  # D, I
+                w3 = self.w3[expert]  # I, D
+
+                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # [T'(e), D]
+                outs.append(cur_out)
+
+            # apply the routing weights to the outputs
+            ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
+            ordered_token_activation_weights = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)  # [T*A, 1]
+            weighted_ordered_outs = ordered_outs * ordered_token_activation_weights  # [T*A, D]
+
+            # sum weighted token-activation outputs together for each token
+            final_out = torch.zeros_like(x)  # [T, D]
+            final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, dim), src=weighted_ordered_outs)
+            return final_out
diff --git a/mixtral-moe/run.sh b/mixtral-moe/run.sh
new file mode 100644
index 00000000..c92b11f1
--- /dev/null
+++ b/mixtral-moe/run.sh
@@ -0,0 +1,18 @@
+export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
+
+# eager baseline:
+# python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth
+
+# compiled, batched:
+# python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4
+
+python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile
+
+# compiled, batched, with profiling:
+# python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 --compile --profile "no_q_profile"
+
+# quant reduced layers
+# Time for inference 2: 2.25 sec total, 88.82 tokens/sec
+# Bandwidth achieved: 8296.64 GB/s
+# Average tokens/sec: 163.61
+# Memory used: 94.12 GB
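
The multi-token branch above shuffles token-activation pairs into per-expert groups, runs each group through its expert, and scatter-adds the weighted results back per token. Below is a standalone, toy-sized sketch of that routing (not part of the patch; the shapes, random weights, and tolerances are made up for illustration), checked against a naive per-token loop:

# Toy recreation of the batched routing in ConditionalFeedForwardAOQuantizable.
# T tokens, E experts, D dim, I intermediate dim, A activated experts.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, E, D, I, A = 6, 4, 8, 16, 2
x = torch.randn(T, D)
w1 = torch.randn(E, I, D)
w2 = torch.randn(E, D, I)
w3 = torch.randn(E, I, D)
gate = torch.randn(E, D)

router_probs = F.softmax(x @ gate.t(), dim=-1)                      # [T, E]
expert_weights, expert_indices = torch.topk(router_probs, A, dim=-1)  # [T, A], [T, A]
expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

# group token-activations by expert, mirroring the patched forward()
ordered_token_activations = expert_indices.view(-1).argsort(stable=True)          # [T*A]
ordered_token_indices = ordered_token_activations.div(A).floor().to(torch.int64)  # [T*A]
num_tokens_per_expert = torch.histc(expert_indices.float(), bins=E + 1, min=-1, max=E)  # [E+1], leading 0
cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64)           # [E+1]

outs = []
for e in range(E):
    idx = ordered_token_indices[cum_tokens_per_expert[e]:cum_tokens_per_expert[e + 1]]  # [T'(e)]
    cur_x = x[idx]                                                                       # [T'(e), D]
    outs.append(F.linear(F.silu(F.linear(cur_x, w1[e])) * F.linear(cur_x, w3[e]), w2[e]))

ordered_outs = torch.cat(outs, dim=0)                                # [T*A, D]
act_weights = expert_weights.view(-1, 1)[ordered_token_activations]  # [T*A, 1]
grouped = torch.zeros_like(x).scatter_add(
    0, ordered_token_indices.unsqueeze(-1).expand(T * A, D), ordered_outs * act_weights
)

# reference: loop over each token and its activated experts directly
ref = torch.zeros_like(x)
for t in range(T):
    for a in range(A):
        e = expert_indices[t, a]
        h = F.linear(F.silu(F.linear(x[t], w1[e])) * F.linear(x[t], w3[e]), w2[e])
        ref[t] += expert_weights[t, a] * h

print(torch.allclose(grouped, ref, rtol=1e-4, atol=1e-4))  # expected: True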