From 311eac1d7aab04c41870153fff58353458998c10 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 7 Aug 2024 07:19:46 +0000 Subject: [PATCH 1/8] Support cogvlm. Optimize cogvlm performance. Patch cogvlm language part. Remove some redundant code. Remove some changes. Remove some variables. feat: change infer_ext ops function param order (#2) feat: support ascend qwen2 and qwen2_moe (#6) * feat: support ascend qwen2 and qwen2_moe * fix: fix ascend mixtral ascend: align attention mask to 32bytes (#7) fix attn args (#9) fix: expand shape of attn_mask (#10) Fix list. --- lmdeploy/pytorch/models/cogvlm.py | 321 +++++++++++++++++++++++++- lmdeploy/pytorch/models/module_map.py | 4 +- 2 files changed, 320 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 45d7a4d01d..fa5ee5c7a8 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -3,19 +3,110 @@ import torch import torch.distributed as dist +import dlinfer.ops as ext_ops from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast +from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb +from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd -from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd +from ..kernels import fill_kv_cache from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, rowwise_parallelize_linear) LANGUAGE_TOKEN_TYPE = 0 VISION_TOKEN_TYPE = 1 + # flake8: noqa: F821 +def get_range_ones(mask): + mask = mask.tolist()[0] + res_range = [] + inv_range = [] + count = 0 + inv_start = 0 + prev_diff = -1 + + # get range for continous ones + for i in range(len(mask)): + count += mask[i] + + # handling ones + # insert range at the end + if i > 0 and mask[i] == 0 and mask[i-1] == 1: + end = i - 1 + res_range.append([start, end]) + # handling zero + # insert inv_range at the end + if i > 0 and mask[i] 
== 1 and mask[i-1] == 0: + inv_end = i - 1 + inv_range.append([inv_start, inv_end]) + # prepare for next range + if i - count != prev_diff: + start = i + 1 + else: + inv_start = i + 1 + prev_diff = i - count + + # last range block + if mask[-1] == 1: + res_range.append([start, len(mask) - 1]) + else: + inv_range.append([inv_start, len(mask) - 1]) + return res_range, inv_range + + +def merge_section_size(left_range, right_range): + def get_len(range_elem): + return range_elem[1] - range_elem[0] + 1 + + res_size = [] + side_idx = [] + l_idx = r_idx = 0 + + # merge sort for l&r array + while l_idx < len(left_range) and r_idx < len(right_range): + if left_range[l_idx] < right_range[r_idx]: + l_len = get_len(left_range[l_idx]) + res_size.append(l_len) + side_idx.append(0) + l_idx += 1 + else: + r_len = get_len(right_range[r_idx]) + res_size.append(r_len) + side_idx.append(1) + r_idx += 1 + + # handle tailing data + if l_idx < len(left_range): + res_size.extend([get_len(elem) for elem in left_range[l_idx:]]) + side_idx.append(0) + if r_idx < len(right_range): + res_size.extend([get_len(elem) for elem in right_range[r_idx:]]) + side_idx.append(1) + return res_size, side_idx + + +def handle_mask_range_split(in_data, lang_fn, vision_fn, context, stage_info=None): + # split inputs for continous slice batch + all_mask_size, side_idx = merge_section_size( + context.vision_token_range, context.language_token_range) + split_hidden_states = torch.split(in_data, all_mask_size, dim=1) + + # calculate and merge + output_layer = [] + for i, elem in enumerate(split_hidden_states): + # language part + if side_idx[i] == 1: + output_layer.append(lang_fn(elem)) + # vision part + else: + output_layer.append(vision_fn(elem)) + output_layer = torch.cat(output_layer, dim=1) + return output_layer + + def get_vision_expert_mask( token_type_ids: 'torch.LongTensor(B, L)' ) -> '[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]': @@ -54,6 +145,32 @@ def forward(self, hidden_states: 'torch.Tensor(B, 
L, D)', return output +class PatchedVisionExpertMLPAscend(nn.Module): + + def forward(self, hidden_states: 'torch.Tensor(B, L, D)', + token_type_ids: 'torch.LongTensor(B, L)'): + context = self.context.context + only_has_language = context.is_decoding + if not context.is_decoding: + # for embedding splitting + if hasattr(context, 'vision_token_mask') and hasattr( + context, 'language_token_mask'): + vision_token_mask = context.vision_token_mask + language_token_mask = context.language_token_mask + only_has_language = vision_token_mask.numel() == 0 + else: + only_has_language = True + + if only_has_language: + output = self.language_mlp(hidden_states) + else: + output = handle_mask_range_split(hidden_states, + self.language_mlp, + self.vision_mlp, + self.context.context) + return output + + class PatchedVisionExpertAttention(nn.Module): def _load_weights(self, loader, rank: int, world_size: int, @@ -107,6 +224,7 @@ def _contiguous_batching_forward_impl( q_seq_length = context.q_seq_length kv_seq_length = context.kv_seq_length block_offsets = context.block_offsets + q_seq_length_list = context.q_seq_length_list max_q_seq_length = context.max_q_seq_length num_heads = self.config.num_attention_heads // world_size num_kv_heads = getattr(self.config, 'num_multi_query_heads', @@ -185,19 +303,24 @@ def __rotary_emb_fn(query_states, key_states, value_states): kv_seq_length=kv_seq_length, max_q_seq_length=max_q_seq_length, block_offsets=block_offsets, + context=self.context.context ) context_layer = query_states paged_attention_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], context_layer, block_offsets, q_start_loc=q_start_loc, q_seqlens=q_seq_length, + q_seqlens_list=q_seq_length_list, kv_seqlens=kv_seq_length, max_seqlen=max_q_seq_length, + context=self.context.context ) context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1) @@ -238,6 +361,181 @@ def forward( ) +class PatchedVisionExpertAttentionAscend(nn.Module): + + 
def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + num_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, 'num_multi_query_heads', num_heads) + head_dim = self.config.hidden_size // num_heads + sections = [ + self.config.hidden_size, num_kv_heads * head_dim, + num_kv_heads * head_dim + ] + for name in [ + 'vision_expert_query_key_value', + 'language_expert_query_key_value' + ]: + colwise_split_parallelize_linear(getattr(self, name), + sections, + loader, + rank=rank, + world_size=world_size, + prefix=name) + for name in ['vision_expert_dense', 'language_expert_dense']: + rowwise_parallelize_linear(getattr(self, name), + loader, + rank=rank, + world_size=world_size, + prefix=name) + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + dist.all_reduce(outputs[0]) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + token_type_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite implementation of Attention.forward. + + Add continuous batching support. Add paged attention support. 
+ """ + context = self.context.context + q_start_loc = context.q_start_loc + q_seq_length = context.q_seq_length + kv_seq_length = context.kv_seq_length + block_offsets = context.block_offsets + q_seq_length_list = context.q_seq_length_list + max_q_seq_length = context.max_q_seq_length + num_heads = self.config.num_attention_heads // world_size + num_kv_heads = getattr(self.config, 'num_multi_query_heads', + self.config.num_attention_heads) // world_size + + head_dim = self.config.hidden_size // self.config.num_attention_heads + hidden_size = num_heads * head_dim + only_has_language = context.is_decoding + if not context.is_decoding: + # for embedding splitting + if hasattr(context, 'vision_token_mask') and hasattr( + context, 'language_token_mask'): + vision_token_mask = context.vision_token_mask + language_token_mask = context.language_token_mask + only_has_language = vision_token_mask.numel() == 0 + else: + only_has_language = True + + def __qkv_proj(hidden_states): + """qkv_proj.""" + if only_has_language: + mixed_raw_layer = self.language_expert_query_key_value( + hidden_states) + else: + mixed_raw_layer = handle_mask_range_split(hidden_states, + self.language_expert_query_key_value, + self.vision_expert_query_key_value, + self.context.context) + + query_states, key_states, value_states = torch.split( + mixed_raw_layer, [ + hidden_size, head_dim * num_kv_heads, + head_dim * num_kv_heads + ], + dim=-1) + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + """rotary embedding func.""" + scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) + inv_freq = self.rotary_emb.inv_freq + + q = query_states[None] + k = key_states[None] + batch, seqlen, _, _ = q.shape + + pos_id = position_ids[None].squeeze(0).unsqueeze(-1) + pos_freq = pos_id / scaling_factor * inv_freq + cos = (torch.cos(pos_freq).view(batch, seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(q.dtype)) + sin = (torch.sin(pos_freq).view(batch, 
seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(q.dtype)) + ext_ops.apply_rotary_pos_emb(q, k, + cos, sin, None, None) + return q[0], k[0], value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + + ext_ops.fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + self.context.context.kv_start_indices + ) + + context_layer = query_states + paged_attention_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + context_layer, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + context=self.context.context + ) + context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1) + + if only_has_language: + attn_output = self.language_expert_dense(context_layer) + else: + attn_output = handle_mask_range_split(context_layer, + self.language_expert_dense, + self.vision_expert_dense, + self.context.context) + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite of forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids=position_ids, + past_key_value=past_key_value, + world_size=world_size, + ) + + class PatchedCogVLMModel(nn.Module): def forward( @@ -258,7 +556,12 @@ def forward( if vision_embeddings is not None and len(vision_embeddings) > 0: # 
multi-modality - inputs_embeds[:, + if len(context.token_type_range) == 1: + inputs_embeds[:, + context.token_type_range[0][0] : context.token_type_range[0][1] + 1, :] = vision_embeddings.to( + inputs_embeds) + else: + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) hidden_states = inputs_embeds @@ -308,6 +611,9 @@ def _get_cogvlm_position_ids(context): """get cogvlm position_ids.""" inputs = context.inputs q_seq_length = inputs.seq_length + + # avoid duplicated seq_len tolist + context.q_seq_length_list = q_seq_length.tolist() vision_input_info = inputs.vision_inputs position_id_offsets = vision_input_info.history_image_token_lengths - vision_input_info.history_image_nums * 3 if inputs.is_decoding: @@ -318,7 +624,10 @@ def _get_cogvlm_position_ids(context): starts = inputs.history_lengths - vision_input_info.history_lengths ends = starts + q_seq_length token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) + torch.int) + + # add token_type_range data_struct to context + context.token_type_range, _ = get_range_ones(vision_input_info.input_embedding_indexing) history_position_lengths = vision_input_info.history_lengths - position_id_offsets position_ids_all = history_position_lengths[:, None] + build_position_ids( @@ -339,6 +648,12 @@ def _get_cogvlm_position_ids(context): context.vision_token_mask = vision_token_mask_new context.language_token_mask = language_token_mask_new + + # add vision & lang token range to context + # vis_tok_mask_all is the original mask + context.vision_token_range, context.language_token_range = get_range_ones( + vision_token_mask_all) + else: position_ids = context.attention_mask.long().cumsum(-1) - 1 position_ids += (inputs.history_lengths - diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index e0f49715b6..ddba39fdb1 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -314,9 +314,9 @@ 
'modeling_cogvlm.MLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', 'modeling_cogvlm.VisionExpertMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLP', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLPAscend', 'modeling_cogvlm.VisionExpertAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttention', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend', 'modeling_cogvlm.CogVLMModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedCogVLMModel', }) From ae0e1ff0899edcdf94a67bfdf5fd3dc218e6845d Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Mon, 26 Aug 2024 02:28:20 +0000 Subject: [PATCH 2/8] Modify module_map implementation. --- lmdeploy/pytorch/models/module_map.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index ddba39fdb1..1dd738c4d4 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -314,9 +314,9 @@ 'modeling_cogvlm.MLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', 'modeling_cogvlm.VisionExpertMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLPAscend', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLP', 'modeling_cogvlm.VisionExpertAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttention', 'modeling_cogvlm.CogVLMModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedCogVLMModel', }) @@ -391,6 +391,14 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', }) +# ascend cogvlm +ASCEND_MODULE_MAP.update({ + 'modeling_cogvlm.VisionExpertMLP': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLPAscend', + 'modeling_cogvlm.VisionExpertAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend', +}) + # ascend mixtral 
ASCEND_MODULE_MAP.update({ 'transformers.models.mixtral.modeling_mixtral.MixtralAttention': From ee689e59d38b81fda2dfd927ebce55506a18dd4d Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Mon, 26 Aug 2024 02:35:21 +0000 Subject: [PATCH 3/8] Remove context to_list. --- lmdeploy/pytorch/models/cogvlm.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index fa5ee5c7a8..5b1c9bdcd7 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -224,7 +224,6 @@ def _contiguous_batching_forward_impl( q_seq_length = context.q_seq_length kv_seq_length = context.kv_seq_length block_offsets = context.block_offsets - q_seq_length_list = context.q_seq_length_list max_q_seq_length = context.max_q_seq_length num_heads = self.config.num_attention_heads // world_size num_kv_heads = getattr(self.config, 'num_multi_query_heads', @@ -317,7 +316,6 @@ def __rotary_emb_fn(query_states, key_states, value_states): block_offsets, q_start_loc=q_start_loc, q_seqlens=q_seq_length, - q_seqlens_list=q_seq_length_list, kv_seqlens=kv_seq_length, max_seqlen=max_q_seq_length, context=self.context.context @@ -414,7 +412,6 @@ def _contiguous_batching_forward_impl( q_seq_length = context.q_seq_length kv_seq_length = context.kv_seq_length block_offsets = context.block_offsets - q_seq_length_list = context.q_seq_length_list max_q_seq_length = context.max_q_seq_length num_heads = self.config.num_attention_heads // world_size num_kv_heads = getattr(self.config, 'num_multi_query_heads', @@ -611,9 +608,6 @@ def _get_cogvlm_position_ids(context): """get cogvlm position_ids.""" inputs = context.inputs q_seq_length = inputs.seq_length - - # avoid duplicated seq_len tolist - context.q_seq_length_list = q_seq_length.tolist() vision_input_info = inputs.vision_inputs position_id_offsets = vision_input_info.history_image_token_lengths - vision_input_info.history_image_nums * 3 if inputs.is_decoding: From 
fb2aa10fcb6183631a4b821fe0079c7f2172d098 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 27 Aug 2024 02:33:45 +0000 Subject: [PATCH 4/8] Simplify code logic. --- lmdeploy/pytorch/models/cogvlm.py | 74 ++----------------------------- 1 file changed, 4 insertions(+), 70 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 5b1c9bdcd7..2d63e4909a 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -8,15 +8,14 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd +from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache -from ..kernels import fill_kv_cache from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, rowwise_parallelize_linear) LANGUAGE_TOKEN_TYPE = 0 VISION_TOKEN_TYPE = 1 - # flake8: noqa: F821 @@ -173,39 +172,6 @@ def forward(self, hidden_states: 'torch.Tensor(B, L, D)', class PatchedVisionExpertAttention(nn.Module): - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - num_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, 'num_multi_query_heads', num_heads) - head_dim = self.config.hidden_size // num_heads - sections = [ - self.config.hidden_size, num_kv_heads * head_dim, - num_kv_heads * head_dim - ] - for name in [ - 'vision_expert_query_key_value', - 'language_expert_query_key_value' - ]: - colwise_split_parallelize_linear(getattr(self, name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=name) - for name in ['vision_expert_dense', 'language_expert_dense']: - rowwise_parallelize_linear(getattr(self, name), - loader, - rank=rank, - world_size=world_size, - prefix=name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - 
"""Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - def _contiguous_batching_forward_impl( self, hidden_states: torch.Tensor, @@ -361,39 +327,6 @@ def forward( class PatchedVisionExpertAttentionAscend(nn.Module): - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - num_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, 'num_multi_query_heads', num_heads) - head_dim = self.config.hidden_size // num_heads - sections = [ - self.config.hidden_size, num_kv_heads * head_dim, - num_kv_heads * head_dim - ] - for name in [ - 'vision_expert_query_key_value', - 'language_expert_query_key_value' - ]: - colwise_split_parallelize_linear(getattr(self, name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=name) - for name in ['vision_expert_dense', 'language_expert_dense']: - rowwise_parallelize_linear(getattr(self, name), - loader, - rank=rank, - world_size=world_size, - prefix=name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - def _contiguous_batching_forward_impl( self, hidden_states: torch.Tensor, @@ -479,12 +412,13 @@ def __rotary_emb_fn(query_states, key_states, value_states): query_states, key_states, value_states = __rotary_emb_fn( query_states, key_states, value_states) - ext_ops.fill_kv_cache( + fill_kv_cache( key_states, value_states, past_key_value[0], past_key_value[1], - self.context.context.kv_start_indices + None, None, None, None, None, + self.context.context ) context_layer = query_states From ee44f4d87cbbd50c9cfa6789d96846e40c44eb35 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 27 Aug 2024 08:56:16 +0000 Subject: [PATCH 5/8] Remove index op replacement. 
--- lmdeploy/pytorch/models/cogvlm.py | 234 +++++++++----------------- lmdeploy/pytorch/models/module_map.py | 2 - 2 files changed, 78 insertions(+), 158 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 2d63e4909a..1178109e70 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -9,6 +9,7 @@ from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache +from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, rowwise_parallelize_linear) @@ -19,93 +20,6 @@ # flake8: noqa: F821 -def get_range_ones(mask): - mask = mask.tolist()[0] - res_range = [] - inv_range = [] - count = 0 - inv_start = 0 - prev_diff = -1 - - # get range for continous ones - for i in range(len(mask)): - count += mask[i] - - # handling ones - # insert range at the end - if i > 0 and mask[i] == 0 and mask[i-1] == 1: - end = i - 1 - res_range.append([start, end]) - # handling zero - # insert inv_range at the end - if i > 0 and mask[i] == 1 and mask[i-1] == 0: - inv_end = i - 1 - inv_range.append([inv_start, inv_end]) - # prepare for next range - if i - count != prev_diff: - start = i + 1 - else: - inv_start = i + 1 - prev_diff = i - count - - # last range block - if mask[-1] == 1: - res_range.append([start, len(mask) - 1]) - else: - inv_range.append([inv_start, len(mask) - 1]) - return res_range, inv_range - - -def merge_section_size(left_range, right_range): - def get_len(range_elem): - return range_elem[1] - range_elem[0] + 1 - - res_size = [] - side_idx = [] - l_idx = r_idx = 0 - - # merge sort for l&r array - while l_idx < len(left_range) and r_idx < len(right_range): - if left_range[l_idx] < right_range[r_idx]: - l_len = get_len(left_range[l_idx]) 
- res_size.append(l_len) - side_idx.append(0) - l_idx += 1 - else: - r_len = get_len(right_range[r_idx]) - res_size.append(r_len) - side_idx.append(1) - r_idx += 1 - - # handle tailing data - if l_idx < len(left_range): - res_size.extend([get_len(elem) for elem in left_range[l_idx:]]) - side_idx.append(0) - if r_idx < len(right_range): - res_size.extend([get_len(elem) for elem in right_range[r_idx:]]) - side_idx.append(1) - return res_size, side_idx - - -def handle_mask_range_split(in_data, lang_fn, vision_fn, context, stage_info=None): - # split inputs for continous slice batch - all_mask_size, side_idx = merge_section_size( - context.vision_token_range, context.language_token_range) - split_hidden_states = torch.split(in_data, all_mask_size, dim=1) - - # calculate and merge - output_layer = [] - for i, elem in enumerate(split_hidden_states): - # language part - if side_idx[i] == 1: - output_layer.append(lang_fn(elem)) - # vision part - else: - output_layer.append(vision_fn(elem)) - output_layer = torch.cat(output_layer, dim=1) - return output_layer - - def get_vision_expert_mask( token_type_ids: 'torch.LongTensor(B, L)' ) -> '[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]': @@ -144,34 +58,41 @@ def forward(self, hidden_states: 'torch.Tensor(B, L, D)', return output -class PatchedVisionExpertMLPAscend(nn.Module): - - def forward(self, hidden_states: 'torch.Tensor(B, L, D)', - token_type_ids: 'torch.LongTensor(B, L)'): - context = self.context.context - only_has_language = context.is_decoding - if not context.is_decoding: - # for embedding splitting - if hasattr(context, 'vision_token_mask') and hasattr( - context, 'language_token_mask'): - vision_token_mask = context.vision_token_mask - language_token_mask = context.language_token_mask - only_has_language = vision_token_mask.numel() == 0 - else: - only_has_language = True - - if only_has_language: - output = self.language_mlp(hidden_states) - else: - output = handle_mask_range_split(hidden_states, - 
self.language_mlp, - self.vision_mlp, - self.context.context) - return output - - class PatchedVisionExpertAttention(nn.Module): + def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + num_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, 'num_multi_query_heads', num_heads) + head_dim = self.config.hidden_size // num_heads + sections = [ + self.config.hidden_size, num_kv_heads * head_dim, + num_kv_heads * head_dim + ] + for name in [ + 'vision_expert_query_key_value', + 'language_expert_query_key_value' + ]: + colwise_split_parallelize_linear(getattr(self, name), + sections, + loader, + rank=rank, + world_size=world_size, + prefix=name) + for name in ['vision_expert_dense', 'language_expert_dense']: + rowwise_parallelize_linear(getattr(self, name), + loader, + rank=rank, + world_size=world_size, + prefix=name) + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + dist.all_reduce(outputs[0]) + return outputs + def _contiguous_batching_forward_impl( self, hidden_states: torch.Tensor, @@ -369,11 +290,18 @@ def __qkv_proj(hidden_states): mixed_raw_layer = self.language_expert_query_key_value( hidden_states) else: - mixed_raw_layer = handle_mask_range_split(hidden_states, - self.language_expert_query_key_value, - self.vision_expert_query_key_value, - self.context.context) + shape = list(hidden_states.shape) + shape[-1] = hidden_size + head_dim * num_kv_heads * 2 + mixed_raw_layer = torch.empty(shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + mixed_raw_layer[:, + vision_token_mask, :] = self.vision_expert_query_key_value( + hidden_states[:, vision_token_mask, :]) + mixed_raw_layer[:, + language_token_mask, :] = self.language_expert_query_key_value( + hidden_states[:, language_token_mask, :]) query_states, key_states, value_states = torch.split( mixed_raw_layer, [ hidden_size, head_dim * num_kv_heads, @@ -387,21 +315,16 
@@ def __rotary_emb_fn(query_states, key_states, value_states): scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) inv_freq = self.rotary_emb.inv_freq - q = query_states[None] - k = key_states[None] - batch, seqlen, _, _ = q.shape - - pos_id = position_ids[None].squeeze(0).unsqueeze(-1) - pos_freq = pos_id / scaling_factor * inv_freq - cos = (torch.cos(pos_freq).view(batch, seqlen, 1, - -1).repeat(1, 1, 1, - 2).to(q.dtype)) - sin = (torch.sin(pos_freq).view(batch, seqlen, 1, - -1).repeat(1, 1, 1, - 2).to(q.dtype)) - ext_ops.apply_rotary_pos_emb(q, k, - cos, sin, None, None) - return q[0], k[0], value_states + query_states, key_states = fused_rotary_emb( + query_states[None], + key_states[None], + position_ids[None], + inv_freq=inv_freq, + scaling_factor=scaling_factor, + out_q=query_states[None], + out_k=key_states[None], + context=context) + return query_states[0], key_states[0], value_states query_states, key_states, value_states = __qkv_proj(hidden_states) @@ -417,8 +340,12 @@ def __rotary_emb_fn(query_states, key_states, value_states): value_states, past_key_value[0], past_key_value[1], - None, None, None, None, None, - self.context.context + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + context=context ) context_layer = query_states @@ -434,17 +361,25 @@ def __rotary_emb_fn(query_states, key_states, value_states): q_seqlens=q_seq_length, kv_seqlens=kv_seq_length, max_seqlen=max_q_seq_length, - context=self.context.context + context=context ) context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1) if only_has_language: attn_output = self.language_expert_dense(context_layer) else: - attn_output = handle_mask_range_split(context_layer, - self.language_expert_dense, - self.vision_expert_dense, - self.context.context) + ctx_shape = list(context_layer.shape) + ctx_shape[-1] *= world_size + attn_output = torch.empty(ctx_shape, + dtype=hidden_states.dtype, + 
device=hidden_states.device) + + attn_output[:, vision_token_mask, :] = self.vision_expert_dense( + context_layer[:, vision_token_mask, :]) + attn_output[:, + language_token_mask, :] = self.language_expert_dense( + context_layer[:, language_token_mask, :]) + return attn_output, None, past_key_value def forward( @@ -487,14 +422,9 @@ def forward( if vision_embeddings is not None and len(vision_embeddings) > 0: # multi-modality - if len(context.token_type_range) == 1: - inputs_embeds[:, - context.token_type_range[0][0] : context.token_type_range[0][1] + 1, :] = vision_embeddings.to( - inputs_embeds) - else: - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): @@ -552,10 +482,8 @@ def _get_cogvlm_position_ids(context): starts = inputs.history_lengths - vision_input_info.history_lengths ends = starts + q_seq_length token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) + torch.int) - # add token_type_range data_struct to context - context.token_type_range, _ = get_range_ones(vision_input_info.input_embedding_indexing) history_position_lengths = vision_input_info.history_lengths - position_id_offsets position_ids_all = history_position_lengths[:, None] + build_position_ids( @@ -576,12 +504,6 @@ def _get_cogvlm_position_ids(context): context.vision_token_mask = vision_token_mask_new context.language_token_mask = language_token_mask_new - - # add vision & lang token range to context - # vis_tok_mask_all is the original mask - context.vision_token_range, context.language_token_range = get_range_ones( - vision_token_mask_all) - else: position_ids = context.attention_mask.long().cumsum(-1) - 1 position_ids += (inputs.history_lengths - diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 
1dd738c4d4..82ea59b600 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -393,8 +393,6 @@ # ascend cogvlm ASCEND_MODULE_MAP.update({ - 'modeling_cogvlm.VisionExpertMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLPAscend', 'modeling_cogvlm.VisionExpertAttention': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend', }) From 9b0841238e5e4dbe7039a5f41d9ddf3bd67711f9 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 27 Aug 2024 09:05:09 +0000 Subject: [PATCH 6/8] Fix some code-style issues. --- lmdeploy/pytorch/models/cogvlm.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 1178109e70..5dca4bfdc8 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist -import dlinfer.ops as ext_ops from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb @@ -189,14 +188,11 @@ def __rotary_emb_fn(query_states, key_states, value_states): kv_seq_length=kv_seq_length, max_q_seq_length=max_q_seq_length, block_offsets=block_offsets, - context=self.context.context ) context_layer = query_states paged_attention_fwd( query_states, - key_states, - value_states, past_key_value[0], past_key_value[1], context_layer, @@ -205,7 +201,6 @@ def __rotary_emb_fn(query_states, key_states, value_states): q_seqlens=q_seq_length, kv_seqlens=kv_seq_length, max_seqlen=max_q_seq_length, - context=self.context.context ) context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1) @@ -423,8 +418,8 @@ def forward( if vision_embeddings is not None and len(vision_embeddings) > 0: # multi-modality inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) + vision_embedding_indexing, :] = 
vision_embeddings.to( + inputs_embeds) hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): @@ -483,7 +478,6 @@ def _get_cogvlm_position_ids(context): ends = starts + q_seq_length token_type_ids = vision_input_info.input_embedding_indexing.to( torch.int) - history_position_lengths = vision_input_info.history_lengths - position_id_offsets position_ids_all = history_position_lengths[:, None] + build_position_ids( From 5d0cafe0cf1b3a309a49fab9ab3328803d0e4c85 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 27 Aug 2024 09:11:09 +0000 Subject: [PATCH 7/8] Split path between ascend and other platforms. --- lmdeploy/pytorch/models/cogvlm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 5dca4bfdc8..43fe5f1788 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -5,11 +5,11 @@ import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb -from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd -from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache -from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb +from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb_ascend +from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd_ascend +from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache_ascend +from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, rowwise_parallelize_linear) @@ -310,7 +310,7 @@ def __rotary_emb_fn(query_states, key_states, value_states): scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) inv_freq = 
self.rotary_emb.inv_freq - query_states, key_states = fused_rotary_emb( + query_states, key_states = fused_rotary_emb_ascend( query_states[None], key_states[None], position_ids[None], @@ -330,7 +330,7 @@ def __rotary_emb_fn(query_states, key_states, value_states): query_states, key_states, value_states = __rotary_emb_fn( query_states, key_states, value_states) - fill_kv_cache( + fill_kv_cache_ascend( key_states, value_states, past_key_value[0], @@ -344,7 +344,7 @@ def __rotary_emb_fn(query_states, key_states, value_states): ) context_layer = query_states - paged_attention_fwd( + paged_attention_fwd_ascend( query_states, key_states, value_states, From cc4c7240019018fbf3190084b1c4858067c41f19 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 27 Aug 2024 09:12:16 +0000 Subject: [PATCH 8/8] Rename import module name. --- lmdeploy/pytorch/models/cogvlm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 43fe5f1788..8a522f0238 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -5,9 +5,9 @@ import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb_ascend -from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd_ascend -from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache_ascend +from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb as fused_rotary_emb_ascend +from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd as paged_attention_fwd_ascend +from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache as fill_kv_cache_ascend from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd from ..weight_loader.dist_utils import (colwise_split_parallelize_linear,