Commit 40e4f8f

[feat][merge] Support one-behind to reduce bubble time. Add profiling code
2 parents f9abaa8 + db8baee commit 40e4f8f

File tree: 17 files changed, +1421 -234 lines

.gitignore

Lines changed: 6 additions & 0 deletions
@@ -167,3 +167,9 @@ applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
 applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png

applications/ColossalChat/coati/dataset/loader.py

Lines changed: 27 additions & 8 deletions
@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
     }
 
     # Format for RL.
-    gt_answer = None
-    if "messages" in chat and "gt_answer" in chat:
-        gt_answer = chat["gt_answer"]
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
     chat = [chat["messages"]]
 
     tokens = []
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
     if gt_answer is not None:
-        gt_answer = tokenizer.encode(
-            gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
-        )
-        gt_answer = gt_answer.squeeze(1)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
-
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -440,3 +442,20 @@ def __getitem__(self, index: int):
             tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
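
The new collate_fn_grpo simply stacks the already-padded tensor fields and carries the string-valued fields (gt_answer, test_cases) through as plain Python lists instead of pre-tokenizing them. A minimal usage sketch with toy, equal-length samples (the import path and tensor sizes are assumptions for illustration, not part of the commit):

import torch
from torch.utils.data import DataLoader

from coati.dataset.loader import collate_fn_grpo  # assumed import path

# Toy samples shaped like the output of apply_chat_template_and_mask:
# every tensor is already padded to the same length (here 16).
samples = [
    {
        "input_ids": torch.ones(16, dtype=torch.long),
        "attention_mask": torch.ones(16, dtype=torch.long),
        "labels": torch.full((16,), -100, dtype=torch.long),
        "gt_answer": "42",
    }
    for _ in range(4)
]

loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn_grpo)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([2, 16])
print(batch["gt_answer"])        # ['42', '42'] -- raw strings, not token ids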

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 116 additions & 50 deletions
Large diffs are not rendered by default.
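
consumer.py is where the commit's headline changes land: the profiling hooks behind enable_profiling and the "one-behind" option n_behind, which lets the trainer consume rollouts that are up to n_behind batches stale so rollout generation and the policy update overlap instead of idling in turn. Since the diff is collapsed here, the following is only a conceptual sketch of that overlap; every name in it (produce_rollouts, train_step, run_pipeline) is invented for illustration and none of it is the repository's actual code.

import queue
import threading
import time


def produce_rollouts(step: int) -> dict:
    # Stand-in for rollout generation on the inference engine.
    time.sleep(0.1)
    return {"step": step}


def train_step(batch: dict) -> None:
    # Stand-in for the GRPO policy update.
    time.sleep(0.1)


def run_pipeline(num_steps: int, n_behind: int = 1) -> None:
    # The queue bounds how far generation may run ahead of training; with
    # n_behind > 0 the producer builds batch i+1 while the trainer is still
    # working on batch i, shrinking the bubble between the two phases.
    rollouts: queue.Queue = queue.Queue(maxsize=n_behind + 1)

    def producer() -> None:
        for step in range(num_steps):
            rollouts.put(produce_rollouts(step))  # blocks once the queue is full

    threading.Thread(target=producer, daemon=True).start()
    for _ in range(num_steps):
        train_step(rollouts.get())


run_pipeline(num_steps=4, n_behind=1)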

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 17 additions & 64 deletions
@@ -1,14 +1,12 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -65,6 +65,8 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -119,20 +121,7 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = grpo_config.get("response_format_tags", None)
-        reward_model_kwargs = {
-            k: v
-            for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        grpo_config.get("response_format_tags", None)
         self.global_step = 0
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -295,12 +284,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
                                 input_ids_forward_micro_batch,
                                 num_action,
-                                self.plugin.shard_config,
+                                shard_config=self.plugin.shard_config,
                             )
                         else:
                             # Dummy reference logprobs for data iterator.
@@ -323,11 +311,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
                             inputs["input_ids"],
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
@@ -370,16 +358,15 @@ def _criterion(outputs, inputs):
                     mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                 else:
-
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                         policy_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
 
                     if self.policy_loss_fn.beta > 0:
@@ -388,11 +375,11 @@ def _criterion(outputs, inputs):
                             input_ids=input_ids_forward_micro_batch,
                             attention_mask=attention_mask_forward_micro_batch,
                         ).logits
-                        reference_action_log_probs = calc_action_log_probs(
+                        reference_action_log_probs = memory_efficient_logprob(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         per_token_kl = (
                             torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +485,6 @@ def _criterion(outputs, inputs):
         else:
             return None
 
-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
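
Throughout this file, calc_action_log_probs(logits / temperature, ...) is swapped for memory_efficient_logprob(...) from coati.distributed.utils, which now takes shard_config as a keyword argument. That utility is not part of this view, so the sketch below only illustrates the usual idea behind such a helper: gather per-token log-probs of the response in chunks so the full [batch, seq, vocab] log-softmax is never materialized at once. The name, signature, and chunk size are assumptions, and tensor-parallel vocab sharding (the shard_config path) is deliberately omitted.

import torch


def memory_efficient_logprob_sketch(
    logits: torch.Tensor,     # [B, S, V] logits for the whole sequence
    input_ids: torch.Tensor,  # [B, S] token ids; targets are the ids shifted left
    num_action: int,          # number of response tokens at the end of the sequence
    chunk_size: int = 1024,
) -> torch.Tensor:
    # Logits at position t predict token t + 1, so the response tokens
    # (the last num_action positions) are scored by the preceding logits.
    action_logits = logits[:, -num_action - 1 : -1, :]
    action_targets = input_ids[:, -num_action:]
    chunks = []
    for start in range(0, num_action, chunk_size):
        lp = torch.log_softmax(action_logits[:, start : start + chunk_size, :].float(), dim=-1)
        tgt = action_targets[:, start : start + chunk_size]
        chunks.append(lp.gather(-1, tgt.unsqueeze(-1)).squeeze(-1))
    return torch.cat(chunks, dim=1)  # [B, num_action] per-token log-probs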

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 9 additions & 9 deletions
@@ -74,9 +74,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        gt_answer = None
-        if "gt_answer" in kwargs:
-            gt_answer = kwargs.pop("gt_answer")
+        gt_answer = kwargs.pop("gt_answer", None)
+        test_cases = kwargs.pop("test_cases", None)
         if self.num_generations > 1:
             input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
             attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@@ -116,8 +115,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
 
         if gt_answer is not None:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = gt_answer
+        if test_cases is not None:
+            data["test_cases"] = test_cases
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
 
@@ -270,11 +270,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         }
 
         data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
-
-        if "gt_answer" in kwargs:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
+        if "gt_answer" in kwargs:
+            data["gt_answer"] = kwargs["gt_answer"]
+        if "test_cases" in kwargs:
+            data["test_cases"] = kwargs["test_cases"]
         return data
 
     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 13 additions & 6 deletions
@@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
     with open(path) as f:
         lines = f.readlines()
         lines = [line for line in lines if line.strip()]
-        return len(lines) - 1
+        return len(lines)
 
 
 def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
@@ -36,7 +36,6 @@ def launch_distributed(
     train_batch_size: int,
     train_minibatch_size: int,
     train_dataset_config: Dict[str, Any],
-    dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
     train_model_config: Dict[str, Any],
@@ -57,6 +56,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -79,6 +80,11 @@ def launch_distributed(
             f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
         )
 
+    # Attention: Ray uses a complex scheduling method that considers various factors, including load balancing.
+    # When requesting resources, it is not guaranteed that they come from a node with a lower node id; this goes
+    # against the design principle of our implementation, so we manually force the scheduling, allocating the
+    # producers to nodes with lower node ids and the consumers to resources on nodes with higher node ids. See
+    # the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
     nodes = ray.nodes()
     node_info = {
         node["NodeID"]: {
@@ -104,7 +110,6 @@ def launch_distributed(
         gpu_to_node_id.pop(0)
         gpu_to_ip_address.pop(0)
         print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
-
         producer = SimpleProducer.options(
             # num_cpus=1,
             # num_cpus=num_proc_per_producer,
@@ -121,7 +126,6 @@ def launch_distributed(
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
             train_dataset_config=train_dataset_config,
-            dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
             tokenizer_config=tokenizer_config,
@@ -131,15 +135,16 @@ def launch_distributed(
             consumer_plugin_config=plugin_config,
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
-            evaluation_function_type=grpo_config["reward_fn_type"],
-            response_format_tags=grpo_config["response_format_tags"],
+            grpo_config=grpo_config,
             eval_save_dir=eval_save_dir,
             eval_generation_config=eval_generation_config,
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         producer_procs.append(producer)
     ray.get([p.setup.remote() for p in producer_procs])
@@ -185,6 +190,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         consumer_procs.append(consumer)
     ray.get([p.setup.remote() for p in consumer_procs])
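
The comment added above points to Ray's NodeAffinitySchedulingStrategy as the mechanism for pinning producers to low-node-id nodes and consumers to high-node-id nodes. A standalone sketch of that mechanism (the Worker actor and the simple first/last split are made up for illustration; the launcher itself derives its placement from ray.nodes() as shown in the diff):

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(num_cpus=1)
class Worker:
    def node(self) -> str:
        return ray.get_runtime_context().get_node_id()


# Sort node ids so "producers" land on the lowest node and "consumers" on the highest.
node_ids = sorted(node["NodeID"] for node in ray.nodes() if node["Alive"])

producer = Worker.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_ids[0], soft=False)
).remote()
consumer = Worker.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_ids[-1], soft=False)
).remote()
print(ray.get([producer.node.remote(), consumer.node.remote()]))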
