@@ -1,14 +1,12 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -65,6 +65,8 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -119,20 +121,7 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = grpo_config.get("response_format_tags", None)
-        reward_model_kwargs = {
-            k: v
-            for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        grpo_config.get("response_format_tags", None)
         self.global_step = 0
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -295,12 +284,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         )
 
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
+                        reference_action_log_probs = memory_efficient_logprob(
+                            reference_model_outputs["outputs"]["logits"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                     else:
                         # Dummy reference logprobs for data iterator.
@@ -323,11 +311,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
+                    action_log_probs = memory_efficient_logprob(
+                        action_logits,
                         inputs["input_ids"],
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
@@ -370,16 +358,15 @@ def _criterion(outputs, inputs):
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
             else:
-
                 policy_model_logits = self.policy_model(
                     input_ids=input_ids_forward_micro_batch,
                     attention_mask=attention_mask_forward_micro_batch,
                 ).logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                     policy_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
 
                 if self.policy_loss_fn.beta > 0:
@@ -388,11 +375,11 @@ def _criterion(outputs, inputs):
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    reference_action_log_probs = calc_action_log_probs(
+                    reference_action_log_probs = memory_efficient_logprob(
                         reference_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
                     per_token_kl = (
                         torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +485,6 @@ def _criterion(outputs, inputs):
         else:
             return None
 
-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()