[feat] support grpo pipeline; update qwen modeling for transformer 4.… #1

Open · wants to merge 9 commits into base: grpo-latest
23 changes: 17 additions & 6 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -34,6 +34,7 @@ def __init__(
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
pp_batch_size: int = 8,
save_interval: int = 100,
save_dir: str = "./model",
):
@@ -54,7 +55,13 @@ def __init__(

self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

# To support pipeline parallelism, pp_batch_size is split across the pipeline stages:
# pp_microbatch_size = pp_batch_size // pp_size, and the schedule runs
# pp_num_microbatches = pp_batch_size // pp_microbatch_size microbatches per step.
self.pp_microbatch_size = pp_batch_size // self.plugin_config.get("pp_size", 1)
self.pp_num_microbatches = pp_batch_size // self.pp_microbatch_size
# assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

self.device = get_current_device()

@@ -66,13 +73,18 @@ def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
pp_size=1,
tp_size=self.plugin_config.get("tp_size", 1),
pp_size=self.plugin_config.get("pp_size", 1),
# microbatch_size=self.pp_microbatch_size,
num_microbatches=self.pp_num_microbatches,
precision="bf16",
zero_stage=1,
enable_flash_attention=True,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size

# if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
# # plugin_config["microbatch_size"] = self.microbatch_size
# plugin_config["num_microbatches"] = plugin_config.get("pp_size", 1)
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
@@ -99,7 +111,6 @@ def loop(self) -> None:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers

for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
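As a side note on the new pp_microbatch_size / pp_num_microbatches fields above, here is a minimal standalone sketch (not part of this PR) of the arithmetic, assuming the default pp_batch_size=8 and the pp_size=2 used in rl_example.py below:

```python
# Hypothetical illustration of the pipeline microbatch split in consumer.py.
pp_batch_size = 8   # samples handled by one pipeline schedule
pp_size = 2         # pipeline stages, i.e. plugin_config.get("pp_size", 1)

pp_microbatch_size = pp_batch_size // pp_size                # 4 samples per pipeline microbatch
pp_num_microbatches = pp_batch_size // pp_microbatch_size    # 2 microbatches per schedule

assert pp_microbatch_size * pp_num_microbatches == pp_batch_size
print(pp_microbatch_size, pp_num_microbatches)  # -> 4 2
```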
126 changes: 102 additions & 24 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,15 +3,18 @@

import ray
import torch
import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import calc_action_log_probs, filter_microbatch_dicts, split_into_microbatches
from coati.trainer.utils import all_reduce_mean
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam


@@ -31,6 +34,7 @@ def __init__(
model_config,
plugin_config,
microbatch_size=1,
pp_batch_size=8,
num_generations=4,
use_wandb=True,
):
@@ -47,6 +51,7 @@ def __init__(
model_config,
plugin_config,
microbatch_size,
pp_batch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -86,10 +91,13 @@ def __init__(
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

self.coordinator = None

def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.coordinator = DistCoordinator()

def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
@@ -106,31 +114,103 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""

# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)

need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
ctx = nullcontext()
# ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
# print(f"Before split Rank {dist.get_rank()}] \
# input_ids {data['input_ids'].shape} \
# attention_mask {data['attention_mask'].shape} \
# action_mask {data['action_mask'].shape} \
# gt_answer {data['gt_answer'].shape}\ ")

data_iter = split_into_microbatches(data, self.pp_microbatch_size) # self.pp_num_microbatches

# print(f"After split Rank {dist.get_rank()}] \
# input_ids {data_iter[0]['input_ids'].shape} \
# attention_mask {data_iter[0]['attention_mask'].shape} \
# action_mask {data_iter[0]['action_mask'].shape} \
# gt_answer {data_iter[0]['gt_answer'].shape}\ ")

input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
gt_answer = data["gt_answer"]
response_idx = data["response_idx"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)

policy_model_logits = None
reference_model_logits = None
if self.booster.plugin.pp_size > 1:
# allowed_keys = ("input_ids", "attention_mask")
# data_iter = [{key: value for key, value in data.items() if key in allowed_keys}]
data_iter = filter_microbatch_dicts(data_iter)
# Each element of data_iter is one pipeline microbatch; execute_pipeline consumes one element per call, so we loop by step index instead of iterating data_iter directly.
step_bar = tqdm(
range(len(data_iter)),
desc="Step",
disable=self.coordinator.rank != self.coordinator.world_size - 1,
)
# Create separate iterators for the policy and reference model passes; a single shared iterator would be exhausted after the policy pass and raise StopIteration on next(data_iter).
data_iter, data_iter_infer = iter(data_iter), iter(data_iter)
for step in step_bar:
policy_model_output = self.booster.execute_pipeline(
data_iter,
self.policy_model,
criterion=lambda x, y: x.logits.mean(),
optimizer=self.optimizer,
return_loss=False,
return_outputs=True,
)

with torch.no_grad():
reference_model_output = self.booster.execute_pipeline(
data_iter_infer,
self.reference_model,
criterion=lambda x, y: x.logits.mean(),
return_loss=False,
return_outputs=True,
)

if self.booster.plugin.stage_manager.is_last_stage():
local_policy_model_logits = policy_model_output["outputs"]["logits"]
local_reference_model_logits = reference_model_output["outputs"]["logits"]
if step == 0:
policy_model_logits = local_policy_model_logits
reference_model_logits = local_reference_model_logits
else:
policy_model_logits = torch.cat((policy_model_logits, local_policy_model_logits), dim=0)
reference_model_logits = torch.cat(
(reference_model_logits, local_reference_model_logits), dim=0
)
if self.booster.plugin.stage_manager.is_last_stage():
print(
f"Rank {dist.get_rank()} step {step} policy_model_logits {policy_model_logits.shape} {policy_model_logits} reference_model_logits {reference_model_logits.shape} {reference_model_logits}"
)

else:
policy_model_logits = self.policy_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]

action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
policy_model_logits, input_ids, num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
reference_model_logits, input_ids, num_action, self.plugin.shard_config
)

per_token_kl = (
@@ -140,13 +220,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
)
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward_group = self.reward_model(input_ids, gt_answer=gt_answer, response_idx=response_idx)

reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
reward = torch.tensor([value[0] for value in reward_group]).to(input_ids.device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(input_ids.device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(input_ids.device)

# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
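One detail worth calling out in the pipeline branch of step(): execute_pipeline advances the iterator it is handed, so the policy and reference passes cannot share one iterator. A minimal sketch of the failure mode, with consume() standing in (hypothetically) for booster.execute_pipeline:

```python
# Each pipeline pass pulls its microbatch via next(); a shared iterator would
# feed microbatch 0 to the policy pass and microbatch 1 to the reference pass.
microbatches = [{"idx": 0}, {"idx": 1}]

def consume(it):
    # stand-in for booster.execute_pipeline, which calls next() on data_iter
    return next(it)

shared = iter(microbatches)
policy_mb = consume(shared)      # {"idx": 0}
reference_mb = consume(shared)   # {"idx": 1} -> mismatched pairing

policy_iter, ref_iter = iter(microbatches), iter(microbatches)
assert consume(policy_iter) == consume(ref_iter)  # both see {"idx": 0}
```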
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -36,6 +36,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
pp_batch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
@@ -94,6 +95,7 @@
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
pp_batch_size=pp_batch_size,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
45 changes: 45 additions & 0 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -26,6 +26,51 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]
return batch


def split_into_microbatches(data_dict, microbatch_size):
"""
Split a dict of tensors into a list of microbatch dicts according to microbatch_size.
:param data_dict: dict of tensors keyed by name; input_ids has shape (batch_size, seq_len)
:param microbatch_size: size of each microbatch
:return: list of microbatch dicts
"""
batch_size = next(iter(data_dict.values())).size(0)
microbatch_dicts = []

for start_idx in range(0, batch_size, microbatch_size):
end_idx = min(start_idx + microbatch_size, batch_size)
microbatch_dict = {}
for key, tensor in data_dict.items():
if tensor.size(0) == batch_size:
microbatch_dict[key] = tensor[start_idx:end_idx]
else:
microbatch_dict[key] = tensor
microbatch_dicts.append(microbatch_dict)

return microbatch_dicts


def cyclic_iter(dataloader):
epoch = 0
while True:
for batch in dataloader:
yield batch
epoch += 1


def filter_microbatch_dicts(microbatch_dicts):
"""
Iterate over microbatch_dicts and drop every key-value pair whose key is not in ("input_ids", "attention_mask").
:param microbatch_dicts: list of microbatch dicts
:return: filtered list of microbatch dicts
"""
filtered_dicts = []
allowed_keys = ("input_ids", "attention_mask")
for microbatch_dict in microbatch_dicts:
filtered_dict = {key: value for key, value in microbatch_dict.items() if key in allowed_keys}
filtered_dicts.append(filtered_dict)
return filtered_dicts


def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth
if "attention_mask" in batch:
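A short usage sketch of the two new helpers (the shapes below are illustrative, not taken from the trainer):

```python
import torch

from coati.distributed.utils import filter_microbatch_dicts, split_into_microbatches

# Illustrative batch of 8 samples with a 16-token sequence length.
data = {
    "input_ids": torch.zeros(8, 16, dtype=torch.long),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
    "gt_answer": torch.zeros(8, 4, dtype=torch.long),
}

micro = split_into_microbatches(data, microbatch_size=4)
assert len(micro) == 2 and micro[0]["input_ids"].shape == (4, 16)

# Keep only the keys the pipeline forward pass expects.
filtered = filter_microbatch_dicts(micro)
assert set(filtered[0].keys()) == {"input_ids", "attention_mask"}
```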
10 changes: 7 additions & 3 deletions applications/ColossalChat/rl_example.py
@@ -14,6 +14,7 @@
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
parser.add_argument("-ppmbs", "--pp-batch-size", type=int, default=8)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
args = parser.parse_args()
@@ -31,13 +32,15 @@
if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
# use_flash_attention_2=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
use_flash_attention_2=True,
# use_flash_attention_2=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False,
)
@@ -89,12 +92,13 @@
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_microbatch_size=args.train_microbatch_size,
pp_batch_size=args.pp_batch_size,
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
train_model_config=train_model_config,
plugin_config={},
plugin_config={"tp_size": 1, "pp_size": 2},
inference_backend=args.backend,
master_addr="localhost",
master_port=29504,
10 changes: 6 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1416,10 +1416,12 @@ def execute_pipeline(
):
return outputs

# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
# Synchronize when training
if torch.is_grad_enabled():
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()

# Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
# Otherwise, synchronize data parallelism gradients of the model.
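The guard added above means gradient synchronization only runs for training passes; the reference-model pass in grpo_consumer.py, which wraps execute_pipeline in torch.no_grad(), now skips it. A minimal sketch of the condition, with the sync calls replaced by a hypothetical print:

```python
import torch

def maybe_sync_grads():
    # Mirrors the new guard: only synchronize when autograd is recording,
    # i.e. when this is a training pass rather than a no_grad inference pass.
    if torch.is_grad_enabled():
        print("sync shared params and sequence-parallel grads")
    else:
        print("no_grad pass: skip grad synchronization")

maybe_sync_grads()            # training context -> synchronizes
with torch.no_grad():
    maybe_sync_grads()        # reference-model pass -> skipped
```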
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
dtype=dtype,
)
else:
log_prob = log_softmax(logits)
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

return log_prob
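The dim=-1 fix matters because the gather on the preceding line assumes log-probabilities normalized over the vocabulary dimension; without an explicit dim, F.log_softmax falls back to a deprecated implicit dimension choice for 3-D inputs. A small sketch of the intended behavior, assuming (batch, seq, vocab) logits:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 5)                 # (batch, seq, vocab)
labels = torch.randint(0, 5, (2, 3))          # token ids

log_prob = F.log_softmax(logits, dim=-1)      # normalize over the vocab axis
token_log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

assert token_log_prob.shape == (2, 3, 1)
# each position's probabilities sum to 1 over the vocabulary
assert torch.allclose(log_prob.exp().sum(dim=-1), torch.ones(2, 3))
```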