[megatron] support reward model #6093
base: main
Conversation
Summary of Changes
Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Megatron-SWIFT framework by integrating support for Reward Model (RM) training. It expands the available RLHF training types to include RM, introduces a new parameter for controlling reward centering, and provides a dedicated trainer class for this purpose. The accompanying documentation has been thoroughly updated to guide users through the new RM training options.
Code Review
This pull request adds support for Reward Model (RM) training to Megatron-SWIFT. The changes include introducing new arguments, updating documentation in both English and Chinese, and adding a new MegatronRewardTrainer. My review has identified a critical bug that will cause a crash when using the new 'rm' rlhf_type, and another critical issue where the new trainer has unimplemented placeholder methods. I've also noted several medium-severity issues, including a potential unhandled error case, documentation typos, and unused imports that affect code quality.
def loss_func(self, output_tensor):
    pass

def forward_step(self, data_iterator, model):
    pass
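For context, a Megatron-style trainer is expected to fill these two hooks roughly as follows. This is only an illustrative sketch of the expected shape, not the code from this PR; the batch keys and the chosen/rejected splitting convention are assumptions.

```python
from functools import partial

import torch.nn.functional as F


class RewardTrainerSketch:
    """Illustrative shape of the two Megatron hooks; not this PR's implementation."""

    def loss_func(self, output_tensor, *, num_samples):
        # Assume output_tensor holds one scalar reward per sequence, with the
        # chosen half of the batch first and the rejected half second.
        rewards_chosen, rewards_rejected = output_tensor.split(num_samples // 2)
        loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
        # Return the loss plus a detached copy for logging, mirroring the usual pattern.
        return loss, {'loss': loss.detach().clone()}

    def forward_step(self, data_iterator, model):
        # Fetch a micro-batch, run the model, and hand Megatron a callback that
        # turns the output tensor into a loss.
        batch = next(data_iterator)
        output_tensor = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        return output_tensor, partial(self.loss_func, num_samples=batch['input_ids'].shape[0])
```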
docs/source/Megatron-SWIFT/命令行参数.md
The punctuation in the list of rlhf_type options is inconsistent: it uses both a full-width comma and a half-width comma. For consistency, it's better to use half-width commas for all separators in the list.
Before: - 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo','kto', 'rm'。
After: - 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo', 'kto', 'rm'。
In addition to inheriting the training parameters, the following parameters are also supported:

Before: - 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo' and 'kto' are available.
After: - 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', kto', and 'rm' are available.
swift/megatron/train/rlhf.py
trainer_mapping = {'dpo': MegatronDPOTrainer, 'kto': MegatronKTOTrainer, 'rm': MegatronRewardTrainer}
trainer_cls = trainer_mapping[args.rlhf_type]
While using a dictionary for the mapping is a good refactoring, direct key access trainer_mapping[args.rlhf_type] will raise a KeyError if an unsupported rlhf_type is provided. The original implementation raised a more informative ValueError. It's better to use .get() and provide a clear error message for unhandled types.
Before:
trainer_mapping = {'dpo': MegatronDPOTrainer, 'kto': MegatronKTOTrainer, 'rm': MegatronRewardTrainer}
trainer_cls = trainer_mapping[args.rlhf_type]

After:
trainer_mapping = {'dpo': MegatronDPOTrainer, 'kto': MegatronKTOTrainer, 'rm': MegatronRewardTrainer}
trainer_cls = trainer_mapping.get(args.rlhf_type)
if trainer_cls is None:
    raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.')
from collections import namedtuple
from functools import partial
from typing import Literal

import torch
from megatron.core import mpu
from megatron.training import get_args, get_timers
from trl import KTOTrainer

from swift.utils import get_current_device, get_logger
/gemini review
Code Review
This pull request adds support for Reward Model (RM) training to Megatron-SWIFT. The changes are comprehensive, including updates to documentation in both English and Chinese, modifications to argument parsing to include RM-specific parameters like center_rewards_coefficient, and the addition of a new MegatronRewardTrainer. The implementation of the reward trainer and the refactoring of existing code to accommodate the new rm rlhf_type are well done. I have a few suggestions to improve code clarity and fix a potential bug.
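As background on center_rewards_coefficient: in TRL's RewardTrainer this coefficient weights an auxiliary term that keeps the mean reward close to zero. A minimal sketch of how such a term is typically combined with the pairwise loss (the function and argument names here are illustrative, not taken from this PR):

```python
import torch
import torch.nn.functional as F


def reward_loss(rewards_chosen, rewards_rejected, center_rewards_coefficient=None):
    # Standard Bradley-Terry pairwise loss on the reward difference.
    loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
    if center_rewards_coefficient is not None:
        # Auxiliary penalty that discourages the rewards from drifting away from zero.
        center_rewards_loss = center_rewards_coefficient * torch.mean(
            (rewards_chosen + rewards_rejected) ** 2)
        loss = loss + center_rewards_loss
    return loss
```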
def __post_init__(self):
    if self.rlhf_type == 'rm':
        self.task_type = 'seq_cls'
        self.num_labels = 1
    super().__post_init__()
While this __post_init__ logic correctly handles the 'rm' case, the rlhf_type argument for this class on line 10 is defined as Literal['dpo', 'kto']. This will prevent the 'rm' option from being used at all, as it will fail during argument parsing. Please update the Literal on line 10 to include 'rm' to make this new logic reachable.
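A minimal sketch of the fix this comment asks for, assuming a dataclass-style arguments class (the class name and surrounding fields are illustrative, not the actual declaration in the PR):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class RLHFArgumentsSketch:
    # Including 'rm' here makes the __post_init__ branch above reachable at parse time.
    rlhf_type: Literal['dpo', 'kto', 'rm'] = 'dpo'
```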
docs/source/Megatron-SWIFT/命令行参数.md
There's a mix of full-width (，) and half-width (,) commas in the list of options. For consistency, it's better to use only half-width commas within the code-formatted text to align with common coding style.
Before: - 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo','kto', 'rm'。
After: - 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo', 'kto', 'rm'。
if margin is not None:
    loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
else:
    loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
The loss calculation logic is slightly duplicated for the cases with and without a margin. You can simplify this by calculating the reward_diff first and then optionally subtracting the margin. This refactoring improves readability and reduces code duplication.
Before:
if margin is not None:
    loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
else:
    loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

After:
reward_diff = rewards_chosen - rewards_rejected
if margin is not None:
    reward_diff = reward_diff - margin
loss = -nn.functional.logsigmoid(reward_diff).mean()
/gemini review
Code Review
This pull request adds support for Reward Model (RM) training within the Megatron-SWIFT framework. This includes adding new arguments, a new trainer (MegatronRewardTrainer), and example scripts, as well as updating the documentation. The changes are well-structured and follow existing patterns in the codebase. I have one suggestion to improve the metric collection logic for consistency and to prevent potential memory issues.
if self.args.center_rewards_coefficient is not None:
    metric['center_rewards_loss'] = center_rewards_loss
The center_rewards_loss tensor is added to the metric dictionary without being detached. This is inconsistent with how other metrics like loss are handled (loss.detach().clone()). While this might not cause an immediate issue if the downstream processing handles it, it's safer and better practice to detach tensors intended for logging, to prevent them from being held in the computation graph, which could lead to increased memory usage.
Before:
if self.args.center_rewards_coefficient is not None:
    metric['center_rewards_loss'] = center_rewards_loss

After:
if self.args.center_rewards_coefficient is not None:
    metric['center_rewards_loss'] = center_rewards_loss.detach().clone()