diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1e85cccb3c5b..3c759213df7b 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -34,6 +34,7 @@ def __init__(
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
         microbatch_size: int = 1,
+        pp_batch_size: int = 8,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -54,7 +55,13 @@ def __init__(
         self.model_config = model_config
         self.plugin_config = plugin_config

-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
+
+        # To support pipeline parallelism, pp_batch_size is used as the batch size
+        # of one pipeline schedule step, so the pipeline microbatch size is
+        # pp_batch_size // pp_size.
+        self.pp_microbatch_size = pp_batch_size // self.plugin_config.get("pp_size", 1)
+        self.pp_num_microbatches = pp_batch_size // self.pp_microbatch_size
+        # assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

         self.device = get_current_device()
@@ -66,13 +73,18 @@ def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

         plugin_config = dict(
-            tp_size=1,
-            pp_size=1,
+            tp_size=self.plugin_config.get("tp_size", 1),
+            pp_size=self.plugin_config.get("pp_size", 1),
+            # microbatch_size=self.pp_microbatch_size,
+            num_microbatches=self.pp_num_microbatches,
             precision="bf16",
             zero_stage=1,
+            enable_flash_attention=True,
         )
-        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+
+        # if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+        #     # plugin_config["microbatch_size"] = self.microbatch_size
+        #     plugin_config["num_microbatches"] = plugin_config.get("pp_size", 1)
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
@@ -99,7 +111,6 @@ def loop(self) -> None:
                     i = 0
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
-
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                             self.buffer.extend(
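With the default `pp_batch_size=8` added above and a hypothetical `pp_size=2`, the derived microbatch values work out as in this small sketch (illustrative numbers only, not part of the patch):

```python
# Illustrative arithmetic only; pp_size=2 is an assumed plugin setting.
pp_batch_size = 8                                            # default added by this patch
pp_size = 2                                                  # assumed number of pipeline stages
pp_microbatch_size = pp_batch_size // pp_size                # 4 samples per pipeline microbatch
pp_num_microbatches = pp_batch_size // pp_microbatch_size    # 2 microbatches per schedule step

assert pp_microbatch_size * pp_num_microbatches == pp_batch_size
print(pp_microbatch_size, pp_num_microbatches)  # 4 2
```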
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index b1edb89bb0e5..3a5c47877ec1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,15 +3,18 @@

 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, filter_microbatch_dicts, split_into_microbatches
 from coati.trainer.utils import all_reduce_mean
+from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
@@ -31,6 +34,7 @@ def __init__(
         model_config,
         plugin_config,
         microbatch_size=1,
+        pp_batch_size=8,
         num_generations=4,
         use_wandb=True,
     ):
@@ -47,6 +51,7 @@ def __init__(
             model_config,
             plugin_config,
             microbatch_size,
+            pp_batch_size,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -86,10 +91,13 @@ def __init__(
         if use_wandb and self.rank == 0:
             self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

+        self.coordinator = None
+
     def setup(self):
         super().setup()
         self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
         self.reference_model, *_ = self.booster.boost(self.reference_model)
+        self.coordinator = DistCoordinator()

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         """
@@ -106,31 +114,103 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         """
         # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
-        action_mask = data["action_mask"]
-        num_action = action_mask.shape[1]
-        old_action_log_probs = data["action_log_probs"]
-        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0

-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        ctx = nullcontext()
+        # ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
+            data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+            # print(f"Before split Rank {dist.get_rank()}] \
+            #     input_ids {data['input_ids'].shape} \
+            #     attention_mask {data['attention_mask'].shape} \
+            #     action_mask {data['action_mask'].shape} \
+            #     gt_answer {data['gt_answer'].shape}")
+
+            data_iter = split_into_microbatches(data, self.pp_microbatch_size)  # self.pp_num_microbatches
+
+            # print(f"After split Rank {dist.get_rank()}] \
+            #     input_ids {data_iter[0]['input_ids'].shape} \
+            #     attention_mask {data_iter[0]['attention_mask'].shape} \
+            #     action_mask {data_iter[0]['action_mask'].shape} \
+            #     gt_answer {data_iter[0]['gt_answer'].shape}")
+
+            input_ids = data["input_ids"]
+            attention_mask = data["attention_mask"]
+            action_mask = data["action_mask"]
+            num_action = action_mask.shape[1]
+            old_action_log_probs = data["action_log_probs"]
+            gt_answer = data["gt_answer"]
+            response_idx = data["response_idx"]
+            response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+
+            policy_model_logits = None
+            reference_model_logits = None
+            if self.booster.plugin.pp_size > 1:
+                # allowed_keys = ("input_ids", "attention_mask")
+                # data_iter = [{key: value for key, value in data.items() if key in allowed_keys}]
+                data_iter = filter_microbatch_dicts(data_iter)
+                # data_iter is not indexed manually here: each element is already one microbatch.
+                step_bar = tqdm(
+                    range(len(data_iter)),
+                    desc="Step",
+                    disable=not self.coordinator.rank == self.coordinator.world_size - 1,
+                )
+                # Two separate iterators are required (policy model vs. reference model);
+                # otherwise next(data_iter) runs past the end of the data.
+                data_iter, data_iter_infer = iter(data_iter), iter(data_iter)
+                for step in step_bar:
+                    policy_model_output = self.booster.execute_pipeline(
+                        data_iter,
+                        self.policy_model,
+                        criterion=lambda x, y: x.logits.mean(),
+                        optimizer=self.optimizer,
+                        return_loss=False,
+                        return_outputs=True,
+                    )
+
+                    with torch.no_grad():
+                        reference_model_output = self.booster.execute_pipeline(
+                            data_iter_infer,
+                            self.reference_model,
+                            criterion=lambda x, y: x.logits.mean(),
+                            return_loss=False,
+                            return_outputs=True,
+                        )
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        local_policy_model_logits = policy_model_output["outputs"]["logits"]
+                        local_reference_model_logits = reference_model_output["outputs"]["logits"]
+                        if step == 0:
+                            policy_model_logits = local_policy_model_logits
+                            reference_model_logits = local_reference_model_logits
+                        else:
+                            policy_model_logits = torch.cat((policy_model_logits, local_policy_model_logits), dim=0)
+                            reference_model_logits = torch.cat(
+                                (reference_model_logits, local_reference_model_logits), dim=0
+                            )
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    print(
+                        f"Rank {dist.get_rank()} step {step} policy_model_logits {policy_model_logits.shape} {policy_model_logits} reference_model_logits {reference_model_logits.shape} {reference_model_logits}"
+                    )
+
+            else:
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                )["logits"]
+
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids,
+                        attention_mask=attention_mask,
+                    )["logits"]
+
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits, input_ids, num_action, self.plugin.shard_config
             )
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
             reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                reference_model_logits, input_ids, num_action, self.plugin.shard_config
             )

             per_token_kl = (
@@ -140,13 +220,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

-            reward_group = self.reward_model(
-                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-            )
+            reward_group = self.reward_model(input_ids, gt_answer=gt_answer, response_idx=response_idx)

-            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+            reward = torch.tensor([value[0] for value in reward_group]).to(input_ids.device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(input_ids.device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(input_ids.device)

             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
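The comment about creating two iterators can be illustrated with a toy example; this is plain Python with made-up data, not code from the patch:

```python
# Toy illustration: why the patch builds two independent iterators over the same microbatch list.
microbatches = [{"input_ids": i} for i in range(4)]

# A single iterator is consumed as it is read, so the policy pass and the
# reference pass would see *different* microbatches:
it = iter(microbatches)
policy_batch = next(it)      # element 0
reference_batch = next(it)   # element 1, not the same microbatch
assert policy_batch != reference_batch

# With two iterators, both passes see the same microbatch on every step:
policy_iter, reference_iter = iter(microbatches), iter(microbatches)
assert next(policy_iter) == next(reference_iter)
```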
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8581ff5865f8..b126a2672d0f 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -36,6 +36,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_microbatch_size: int,
+    pp_batch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
@@ -94,6 +95,7 @@ def launch_distributed(
             model_config=train_model_config,
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
+            pp_batch_size=pp_batch_size,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 919e4434faa6..5f879a8f3732 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -26,6 +26,51 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
     return batch


+def split_into_microbatches(data_dict, microbatch_size):
+    """
+    Split a dict of tensors into a list of microbatch dicts according to microbatch_size.
+    :param data_dict: dict of tensors, where input_ids has shape (batch_size, seq_len, hidden_dim)
+    :param microbatch_size: size of each microbatch
+    :return: list of microbatch dicts
+    """
+    batch_size = next(iter(data_dict.values())).size(0)
+    microbatch_dicts = []
+
+    for start_idx in range(0, batch_size, microbatch_size):
+        end_idx = min(start_idx + microbatch_size, batch_size)
+        microbatch_dict = {}
+        for key, tensor in data_dict.items():
+            if tensor.size(0) == batch_size:
+                microbatch_dict[key] = tensor[start_idx:end_idx]
+            else:
+                microbatch_dict[key] = tensor
+        microbatch_dicts.append(microbatch_dict)
+
+    return microbatch_dicts
+
+
+def cyclic_iter(dataloader):
+    epoch = 0
+    while True:
+        for batch in dataloader:
+            yield batch
+        epoch += 1
+
+
+def filter_microbatch_dicts(microbatch_dicts):
+    """
+    Iterate over microbatch_dicts and drop every key that is not in ("input_ids", "attention_mask").
+    :param microbatch_dicts: list of microbatch dicts
+    :return: filtered list of microbatch dicts
+    """
+    filtered_dicts = []
+    allowed_keys = ("input_ids", "attention_mask")
+    for microbatch_dict in microbatch_dicts:
+        filtered_dict = {key: value for key, value in microbatch_dict.items() if key in allowed_keys}
+        filtered_dicts.append(filtered_dict)
+    return filtered_dicts
+
+
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
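A possible usage sketch for the two new helpers in `utils.py`; it assumes the patched `coati` package is importable, and the tensor shapes are only illustrative:

```python
import torch
from coati.distributed.utils import filter_microbatch_dicts, split_into_microbatches

data = {
    "input_ids": torch.randint(0, 100, (8, 16)),          # (batch_size, seq_len)
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
    "gt_answer": torch.zeros(8, 4),                        # extra key, dropped before the pipeline pass
}

microbatches = split_into_microbatches(data, microbatch_size=4)
assert len(microbatches) == 2
assert microbatches[0]["input_ids"].shape == (4, 16)

# Keep only the keys the pipeline forward pass accepts.
pipeline_inputs = filter_microbatch_dicts(microbatches)
assert set(pipeline_inputs[0]) == {"input_ids", "attention_mask"}
```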
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 1de8b649d5d1..9d282fe4bf0b 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -14,6 +14,7 @@
     parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
+    parser.add_argument("-ppmbs", "--pp-batch-size", type=int, default=8)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
     args = parser.parse_args()
@@ -31,13 +32,15 @@
     if args.backend == "transformers":
         inference_model_config.update(
             dict(
-                use_flash_attention_2=True,
+                # use_flash_attention_2=True,
+                attn_implementation="flash_attention_2",
                 torch_dtype=torch.bfloat16,
             )
         )
         train_model_config.update(
             dict(
-                use_flash_attention_2=True,
+                # use_flash_attention_2=True,
+                attn_implementation="flash_attention_2",
                 torch_dtype=torch.bfloat16,
                 use_cache=False,
             )
         )
@@ -89,12 +92,13 @@
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_microbatch_size=args.train_microbatch_size,
+        pp_batch_size=args.pp_batch_size,
         dataset_config={"path": args.dataset, "max_length": 300},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
         train_model_config=train_model_config,
-        plugin_config={},
+        plugin_config={"tp_size": 1, "pp_size": 2},
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29504,
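For reference, the `attn_implementation` kwarg replaces the deprecated `use_flash_attention_2` flag when calling `from_pretrained`. A minimal sketch (the checkpoint name is a placeholder and flash-attn plus a CUDA device are assumed):

```python
import torch
from transformers import AutoModelForCausalLM

model_config = dict(
    attn_implementation="flash_attention_2",  # replaces the deprecated use_flash_attention_2=True
    torch_dtype=torch.bfloat16,
    use_cache=False,
)

# Placeholder checkpoint path; loading requires flash-attn to be installed.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", **model_config)
```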
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702e70..f17ac94b4e30 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1416,10 +1416,12 @@ def execute_pipeline(
         ):
             return outputs

-        # Synchronize the grads of shared parameters of the model.
-        model.sync_shared_params()
-        # Synchronize sequence parallelism gradients of the model.
-        model.sync_sp_grads()
+        # Synchronize when training
+        if torch.is_grad_enabled():
+            # Synchronize the grads of shared parameters of the model.
+            model.sync_shared_params()
+            # Synchronize sequence parallelism gradients of the model.
+            model.sync_sp_grads()

         # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
         # Otherwise, synchronize data parallelism gradients of the model.
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 51419a38a0ed..a9bb76fc7d6b 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

     return log_prob
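The explicit `dim=-1` in `dist_log_prob` pins the normalization to the vocabulary axis; when the dimension is omitted, PyTorch emits a deprecation warning and picks an axis itself, which for 3-D logits is generally not the vocabulary axis (treat the exact fallback choice as an implementation detail). A small sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)  # (batch, seq_len, vocab_size)

# Explicit: normalize over the vocabulary axis, as the gather over labels expects.
log_prob = F.log_softmax(logits, dim=-1)
assert torch.allclose(log_prob.exp().sum(dim=-1), torch.ones(2, 5))

# Omitting dim=... triggers a deprecation warning and lets PyTorch infer the axis,
# which can silently normalize over the wrong dimension for 3-D inputs.
```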
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 71e3557fe214..abe1b85ac648 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -131,7 +131,7 @@ def qwen2_model_forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -152,10 +152,10 @@ def qwen2_model_forward(
                 is_causal=True,
             )
         else:
-            if self._attn_implementation == "flash_attention_2":
+            if self.config._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            elif self._attn_implementation == "sdpa" and not output_attentions:
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
                 # output_attentions=True can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
                 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -534,7 +534,8 @@ def forward(
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)  # if transformers version <= 4.39.3
+        cos, sin = self.rotary_emb(value_states, position_ids)  # if transformers version > 4.39.3
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 0adcdfdbd553..fd14029a3a36 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -11,6 +11,7 @@
     Linear1D_Row,
     LinearWithGradAccum,
     PaddingEmbedding,
+    PaddingLMHead,
     RMSNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
@@ -449,7 +450,7 @@ def module_policy(self):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="lm_head",
-                        target_module=LinearWithGradAccum,
+                        target_module=PaddingLMHead,
                         kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
                     ),
                     SubModuleReplacementDescription(
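The rotary-embedding call site above depends on the installed transformers version, as the inline comments note. If both code paths ever needed to coexist, a version guard along these lines could be used (a sketch, not part of the patch; `compute_rotary` is a hypothetical helper):

```python
# Sketch: pick the rotary-embedding call signature based on the installed
# transformers version, mirroring the comments in the patched qwen2.py.
from packaging import version

import transformers


def compute_rotary(rotary_emb, value_states, position_ids, rotary_seq_len):
    if version.parse(transformers.__version__) > version.parse("4.39.3"):
        # Newer transformers: the rotary embedding takes position_ids.
        return rotary_emb(value_states, position_ids)
    # Older transformers (<= 4.39.3): the rotary embedding takes seq_len.
    return rotary_emb(value_states, seq_len=rotary_seq_len)
```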