[feat] Support Zero Bubble StreamRL-like RL Training #6356

Draft: wants to merge 7 commits into base: grpo-latest

70 changes: 67 additions & 3 deletions applications/ColossalChat/coati/distributed/comm.py
@@ -1,5 +1,7 @@
import copy
from typing import Any, Dict

import ray
import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
@@ -32,7 +34,12 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =


def ray_broadcast_tensor_dict(
- tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+ tensor_dict: Dict[str, torch.Tensor],
+ src: int = 0,
+ device=None,
+ group_name: str = "default",
+ backend: str = "nccl",
+ offload_to_cpu: bool = False,
) -> Dict[str, torch.Tensor]:
rank = cc.get_rank(group_name)
if rank == src:
@@ -46,12 +53,69 @@ def ray_broadcast_tensor_dict(
out_dict = {}
for k, shape, dtype in metadata:
if rank == src:
- tensor = tensor_dict[k]
+ if offload_to_cpu:
+ tensor = tensor_dict[k].to(device)
+ else:
+ tensor = tensor_dict[k]
else:
tensor = torch.empty(shape, dtype=dtype, device=device)
if backend == "gloo" and dtype == torch.bfloat16:
# Gloo does not support bfloat16; reinterpret the bits as float16 for transport
tensor = tensor.view(torch.float16)
cc.broadcast(tensor, src, group_name)
if backend == "gloo" and dtype == torch.bfloat16:
# Reinterpret the received bits back as bfloat16
tensor = tensor.view(torch.bfloat16)
if rank != src:
- out_dict[k] = tensor
+ if offload_to_cpu:
+ out_dict[k] = tensor.cpu()
+ else:
+ out_dict[k] = tensor
if rank == src:
out_dict = tensor_dict
return out_dict
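
For reference, a minimal usage sketch of the extended helper (the `SyncWorker` actor, the `demo_sync` group name, and the two-rank layout are illustrative assumptions, not part of this PR): the source rank passes the tensor dict, every other rank passes `None`, and `backend="gloo"` with `offload_to_cpu=True` keeps the whole transfer in CPU memory.

```python
import ray
import ray.util.collective as cc
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


@ray.remote
class SyncWorker:
    def setup(self, world_size: int, rank: int):
        # gloo keeps the collective on CPU, so no GPU is tied up by weight sync
        cc.init_collective_group(world_size, rank, backend="gloo", group_name="demo_sync")
        self.rank = rank

    def sync(self, state_dict=None):
        # rank 0 broadcasts its dict; the other ranks receive freshly allocated copies
        return ray_broadcast_tensor_dict(
            state_dict if self.rank == 0 else None,
            src=0,
            device=torch.device("cpu"),
            group_name="demo_sync",
            backend="gloo",
            offload_to_cpu=True,
        )


if __name__ == "__main__":
    ray.init()
    workers = [SyncWorker.remote() for _ in range(2)]
    ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
    weights = {"linear.weight": torch.randn(4, 4, dtype=torch.bfloat16)}
    ray.get([workers[0].sync.remote(weights), workers[1].sync.remote()])
```

The bfloat16 entry exercises the gloo float16 round-trip handled above.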


@ray.remote
class SharedVariableActor:
def __init__(self, number_of_readers: int = 1):
self.data_queue = []
self.data_uid = 0
self.number_of_readers = number_of_readers
self.signals = {}
self.signal_procs_meet_count = {}

def get_queued_data_size(self):
queued_data_size = sum([data[1]["input_ids"].size(0) for data in self.data_queue])
return queued_data_size

def append_data(self, data):
self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count]
self.data_uid += 1
return True

def get_data(self, data_uid: int):
# for multi-process data reading
if not self.data_queue:
# no data in the queue, return None
return None
to_pop_index = None
ret = None
for i, (uid, data, access_count) in enumerate(self.data_queue):
if uid == data_uid:
# found the data with the given uid
self.data_queue[i][2] += 1
ret = copy.deepcopy(data)
if self.data_queue[i][2] == self.number_of_readers:
to_pop_index = i
break
if to_pop_index is not None:
# remove the data from the queue if it has been accessed by all readers
self.data_queue.pop(to_pop_index)
return ret

def set_signal(self, key: str, signal: str):
self.signals[key] = signal

def get_signal(self):
return self.signals
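
A quick sketch of the queue's multi-reader semantics (the batch shape and `number_of_readers=2` are made-up values, not taken from the PR): entries are tagged with an auto-incremented uid, every reader fetches its own deep copy by uid, and an entry is evicted once the last of the `number_of_readers` readers has fetched it. The same actor doubles as the signal board used for the weight-sync handshake.

```python
import ray
import torch

from coati.distributed.comm import SharedVariableActor

ray.init()
queue = SharedVariableActor.remote(number_of_readers=2)

# producer side: push one rollout batch, which gets uid 0
batch = {"input_ids": torch.zeros(8, 128, dtype=torch.long)}
ray.get(queue.append_data.remote(batch))
print(ray.get(queue.get_queued_data_size.remote()))  # 8 queued samples

# consumer side: two data-parallel readers each pull uid 0;
# the second read is the last one, so the entry is popped afterwards
first_copy = ray.get(queue.get_data.remote(0))
second_copy = ray.get(queue.get_data.remote(0))
print(ray.get(queue.get_data.remote(0)))  # None: the queue is empty again

# signal board used for the weight-sync handshake
ray.get(queue.set_signal.remote("consumer", "ready_sync_model"))
print(ray.get(queue.get_signal.remote()))  # {'consumer': 'ready_sync_model'}
```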
418 changes: 208 additions & 210 deletions applications/ColossalChat/coati/distributed/consumer.py

Large diffs are not rendered by default.

136 changes: 136 additions & 0 deletions applications/ColossalChat/coati/distributed/distributor.py
@@ -0,0 +1,136 @@
import time

import ray
import ray.util.collective as cc
import torch
from coati.distributed.profiling_utils import CustomProfiler

from colossalai.utils import get_current_device

from .comm import SharedVariableActor, ray_broadcast_tensor_dict


@ray.remote
class Distributor:
def __init__(
self,
distributor_id,
consumer_pp_size,
num_producers,
shared_signal_actor: SharedVariableActor,
enable_profiling: bool = True,
):
self.distributor_id = distributor_id
self.consumer_pp_size = consumer_pp_size
self.state_dict_cpu = {i: {"not_ready_sync_model": torch.ones((1)).cpu()} for i in range(self.consumer_pp_size)}
self.num_producers = num_producers
self.shared_signal_actor = shared_signal_actor
self.device = get_current_device()
self.profiler = CustomProfiler(f"D{self.distributor_id}", disabled=not enable_profiling)
self.weight_version = {i: 0 for i in range(self.consumer_pp_size)}
self.producer_weight_version = {
j: {f"producer_{i}": 0 for i in range(self.num_producers)} for j in range(self.consumer_pp_size)
}

def init_collective_group(
self,
world_size: int,
rank: int,
backend: str = "nccl",
group_name: str = "default",
gloo_timeout: int = 3000000,
):
cc.init_collective_group(
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
)
print(f"[D] Initialized {group_name} collective group", flush=True)

def loop(self):
while True:
time.sleep(1)
signal = ray.get(self.shared_signal_actor.get_signal.remote())
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_consumer_pp_{i}")
cc.barrier(group_name="distributor_pg")
ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
# Receive the model state dict that consumer pp rank i broadcasts over gloo
self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
None,
0,
device=torch.device("cpu"),
group_name=f"sync_model_consumer_pp_{i}",
backend="gloo",
)
self.profiler.exit(f"sync_model_consumer_pp_{i}")
self.weight_version[i] += 1
for i in range(self.consumer_pp_size):
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
# Broadcast the model state dict to this distributor's paired producer
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.distributor_id}_pp_{i}", "not_ready_sync_model"
)
)
if self.producer_weight_version[i][f"producer_{self.distributor_id}"] < self.weight_version[i]:
self.producer_weight_version[i][f"producer_{self.distributor_id}"] = self.weight_version[i]
ray_broadcast_tensor_dict(
self.state_dict_cpu[i],
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
backend="gloo",
)
else:
# broadcast a one-element dummy dict to complete the handshake without resending unchanged weights
ray_broadcast_tensor_dict(
{"not_ready_sync_model": torch.ones((1)).cpu()},
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
backend="gloo",
)
self.profiler.exit(f"sync_model_producer_{self.distributor_id}_pp_{i}")
else:
if signal.get("consumer", None) == "ready_sync_model":
self.profiler.enter("sync_model_consumer")
cc.barrier(group_name="distributor_pg")
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model"))
# Receive the model state dict that the consumer broadcasts over gloo
self.state_dict_cpu = ray_broadcast_tensor_dict(
None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
)
self.profiler.exit("sync_model_consumer")
self.weight_version[0] += 1
if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
# Broadcast the model state dict to this distributor's paired producer
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.distributor_id}", "not_ready_sync_model"
)
)
if self.producer_weight_version[0][f"producer_{self.distributor_id}"] < self.weight_version[0]:
self.producer_weight_version[0][f"producer_{self.distributor_id}"] = self.weight_version[0]
ray_broadcast_tensor_dict(
self.state_dict_cpu,
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}",
backend="gloo",
)
else:
# broadcast a one-element dummy dict to complete the handshake without resending unchanged weights
ray_broadcast_tensor_dict(
{"not_ready_sync_model": torch.ones((1)).cpu()},
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}",
backend="gloo",
)
self.profiler.exit(f"sync_model_producer_{self.distributor_id}")
if signal.get("consumer", None) == "terminate":
self.profiler.log("terminate sync model worker")
break
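
The producer side of this handshake is not part of this diff, but its expected shape can be sketched as follows (the helper name `pull_latest_weights` and its structure are assumptions; the group and signal names mirror the non-pipeline-parallel branch above, and the producer is assumed to have already joined the `sync_model_producer_{id}` gloo group). The producer raises `ready_sync_model`, then joins the broadcast in which the distributor acts as source rank 1; a one-element dummy payload means the weights have not changed since the last sync.

```python
import ray
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


def pull_latest_weights(shared_signal_actor, producer_id: int):
    """Hypothetical producer-side helper for the non-pipeline-parallel case."""
    # 1. tell the paired distributor that this producer wants a weight sync
    ray.get(
        shared_signal_actor.set_signal.remote(f"producer_{producer_id}", "ready_sync_model")
    )
    # 2. join the broadcast; the distributor is rank 1 of this two-member gloo group
    state_dict = ray_broadcast_tensor_dict(
        None,
        src=1,
        device=torch.device("cpu"),
        group_name=f"sync_model_producer_{producer_id}",
        backend="gloo",
    )
    # 3. the dummy payload signals "no newer weights"; keep the current ones
    if "not_ready_sync_model" in state_dict:
        return None
    return state_dict
```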
41 changes: 23 additions & 18 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -4,6 +4,7 @@
import ray
import torch
import wandb
from coati.distributed.comm import SharedVariableActor
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import memory_efficient_logprob
@@ -18,14 +19,15 @@
class GRPOConsumer(BaseConsumer):
def __init__(
self,
shared_sync_data_actor: SharedVariableActor,
shared_signal_actor: SharedVariableActor,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
- num_update_per_episode,
- num_recv_per_update,
+ train_dataset_size,
+ batch_size,
model_config,
plugin_config,
@@ -38,6 +40,7 @@ def __init__(
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
enable_profiling: bool = False,
):
print(f"Using GRPO config: {grpo_config}")
if (
@@ -49,20 +52,22 @@
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
)
super().__init__(
shared_sync_data_actor,
shared_signal_actor,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
- num_update_per_episode,
- num_recv_per_update,
+ train_dataset_size,
+ batch_size,
model_config,
plugin_config,
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
enable_profiling=enable_profiling,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -120,13 +125,6 @@ def __init__(
grpo_config.get("response_format_tags", None)
self.global_step = 0

- self.lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=self.optimizer,
- total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
- warmup_steps=0,
- eta_min=0.1 * grpo_config.get("lr", 1e-6),
- )

def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -143,14 +141,21 @@ def setup(self):
group=self.wandb_group_name,
)

self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.train_dataset_size // (self.batch_size * self.dp_size),
warmup_steps=0,
eta_min=0.1 * self.grpo_config.get("lr", 1e-6),
)

self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")

- def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
+ def step(self, pbar: Any, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
@@ -212,15 +217,13 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
group_ans_acc_mean < self.filter_range[1],
),
)
- self.effective_prompt_count += group_reward.size(0) * self.dp_size
+ self.effective_prompt_count += (
+ group_reward.size(0) * self.dp_size
+ ) # all prompts in the batch are effective as we filtered out the bad ones before step.

mean_kl, mean_loss = [], []

- if self.grpo_config.get("dynamic_batching", True):
- need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
- else:
- # If dynamic batching is disabled, we need to use all samples for training.
- need_update = (step_idx + 1) % self.num_microbatches == 0
+ need_update = self.effective_prompt_count >= self.batch_size * self.dp_size

effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
@@ -428,6 +431,8 @@ def _criterion(outputs, inputs):
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
if self.lr_scheduler is not None:
self.lr_scheduler.step()
# no need to run all-reduce, as raw_train_batch_* are not split across dp ranks
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
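
With the scheduler now created in `setup()` and stepped once per optimizer update, its horizon is derived from the dataset size instead of `num_update_per_episode`. A back-of-the-envelope check with made-up numbers (none of these values come from the PR; `batch_size` is assumed to be the per-dp-rank prompt budget per update):

```python
num_episodes = 10
train_dataset_size = 8192  # prompts in the training set (assumed)
batch_size = 64            # prompts per optimizer update on each dp rank (assumed)
dp_size = 4

# mirrors the expression used in GRPOConsumer.setup()
total_steps = min(num_episodes, 4) * train_dataset_size // (batch_size * dp_size)
print(total_steps)  # 128 cosine-annealing steps, one per optimizer.step()
```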