
[Finetune] Integrate DPO trainer for CPU and Gaudi #238


Open
wants to merge 58 commits into base: main
Changes from all commits (58 commits)
9148498
add part of DPO code
minmingzhu May 23, 2024
5f8f27f
integrate DPO trainer
minmingzhu May 24, 2024
e350afd
update
minmingzhu May 27, 2024
64df5df
update
minmingzhu May 29, 2024
b7deff5
update
minmingzhu May 29, 2024
dd8b695
1. code format
minmingzhu May 29, 2024
eae05a6
update
minmingzhu May 29, 2024
016c423
add DPO example
minmingzhu May 29, 2024
690eb27
update
minmingzhu May 31, 2024
9e4492a
fix comments
minmingzhu Jun 5, 2024
2b1362b
update
minmingzhu Jun 5, 2024
d9bf7f2
Update huggingface_dataset.py
minmingzhu Jun 5, 2024
d56a4a6
Update finetune_config.py
minmingzhu Jun 5, 2024
1def6b3
update
minmingzhu Jun 15, 2024
3ba941b
Update test_openai_protocol.py
minmingzhu Jun 14, 2024
f8b2a22
1. fix comments
minmingzhu Jun 16, 2024
4bcb687
format
minmingzhu Jun 16, 2024
c425143
update
minmingzhu Jun 16, 2024
6f79859
update
minmingzhu Jun 16, 2024
ccb4bdb
update
minmingzhu Jun 16, 2024
d65b0d6
update
minmingzhu Jun 17, 2024
2804b00
debug CI
minmingzhu Jun 17, 2024
aebc59e
update CI
minmingzhu Jun 17, 2024
83f8837
update CI
minmingzhu Jun 17, 2024
5931fac
update CI
minmingzhu Jun 17, 2024
c9e3636
update CI
minmingzhu Jun 19, 2024
8879dfc
1. update CI
minmingzhu Jun 19, 2024
28ed9c6
update CI
minmingzhu Jun 19, 2024
83e95fe
update doc
minmingzhu Jun 20, 2024
0570bf2
refactor
minmingzhu Jun 20, 2024
a298f52
update
minmingzhu Jul 2, 2024
03dab09
add part of DPO code
minmingzhu May 23, 2024
22ae095
integrate DPO trainer
minmingzhu May 24, 2024
2025ec2
update
minmingzhu May 27, 2024
3f28159
update
minmingzhu May 29, 2024
65de95b
update
minmingzhu May 29, 2024
08b9669
1. code format
minmingzhu May 29, 2024
0786b74
update
minmingzhu May 29, 2024
7072159
add DPO example
minmingzhu May 29, 2024
8f5ad0c
update
minmingzhu May 31, 2024
8e8ec2e
fix comments
minmingzhu Jun 5, 2024
1e07b7a
update
minmingzhu Jun 5, 2024
260980a
Update huggingface_dataset.py
minmingzhu Jun 5, 2024
e9116b5
Update finetune_config.py
minmingzhu Jun 5, 2024
c0dc22f
update
minmingzhu Jun 15, 2024
0d18582
1. fix comments
minmingzhu Jun 16, 2024
66dab6d
update
minmingzhu Jun 16, 2024
45fd5f4
update
minmingzhu Jun 16, 2024
aeffbd5
update
minmingzhu Jun 17, 2024
f4abc4d
update CI
minmingzhu Jun 17, 2024
3ce7d76
update CI
minmingzhu Jun 17, 2024
3d16a45
update CI
minmingzhu Jun 17, 2024
9a09fbc
update CI
minmingzhu Jun 19, 2024
71c47af
1. update CI
minmingzhu Jun 19, 2024
c506398
update doc
minmingzhu Jun 20, 2024
05fdf80
refactor
minmingzhu Jun 20, 2024
fb3152e
update
minmingzhu Jul 2, 2024
007f9c7
update
minmingzhu Jul 2, 2024
6 changes: 6 additions & 0 deletions .github/workflows/workflow_finetune.yml
@@ -88,6 +88,12 @@ jobs:
source dev/scripts/ci-functions.sh
finetune_test ${{ matrix.model }}

- name: Run Finetune DPO Test
run: |
TARGET="finetune"
source dev/scripts/ci-functions.sh
finetune_dpo_test ${{ matrix.model }}

- name: Run PEFT-LoRA Test
run: |
source dev/scripts/ci-functions.sh
15 changes: 15 additions & 0 deletions dev/scripts/ci-functions.sh
@@ -216,6 +216,21 @@ finetune_test(){
docker exec "finetune" bash -c "llm_on_ray-finetune --config_file llm_on_ray/finetune/finetune.yaml"
}

finetune_dpo_test(){
local model=$1
# Check if the model is 'EleutherAI/gpt-j-6b' or 'gpt2'
if [ "$model" == "EleutherAI/gpt-j-6b" ] || [ "$model" == "gpt2" ]; then
echo "Model '$model' is not supported for this operation."
return
fi
echo Set finetune source config :
docker exec "finetune" bash -c "source \$(python -c 'import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)')/env/setvars.sh; RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --head --node-ip-address 127.0.0.1 --ray-debugger-external; RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --address='127.0.0.1:6379' --ray-debugger-external"
echo Set "${model}" patch_yaml_config :
docker exec "finetune" bash -c "python dev/scripts/patch_yaml_config.py --conf_path "llm_on_ray/finetune/finetune.yaml" --models ${model} --dpo"
echo Start "${model}" DPO finetune :
docker exec "finetune" bash -c "llm_on_ray-finetune --config_file llm_on_ray/finetune/finetune.yaml"
}

peft_lora_test(){
local model=$1
docker exec "finetune" bash -c "rm -rf /tmp/llm-ray/*"
8 changes: 8 additions & 0 deletions dev/scripts/patch_yaml_config.py
@@ -24,6 +24,8 @@ def patch_yaml_config():
parser.add_argument("--conf_path", type=str)
parser.add_argument("--models", type=str)
parser.add_argument("--peft_lora", action="store_true", default=False)
parser.add_argument("--dpo", action="store_true", default=False)

args = parser.parse_args()

conf_path = args.conf_path
@@ -69,6 +71,12 @@ def patch_yaml_config():
result["General"]["lora_config"]["target_modules"] = None
else:
result["General"]["lora_config"] = None
if args.dpo:
if "finetuning_model" not in result["Training"]:
result["Training"]["finetuning_model"] = {}
result["Dataset"]["train_file"] = "examples/data/sample_dpo_data.jsonl"
result["Training"]["beta"] = 0.1
result["Training"]["finetuning_model"]["dpo"] = True

with open(conf_path, "w") as output:
yaml.dump(result, output, sort_keys=False)
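For reference, the effect of the new --dpo flag can be sanity-checked by reloading the patched file. This is a minimal sketch that only assumes the config path used in the CI functions above and the keys written by patch_yaml_config.py:

import yaml

# Re-open the config that patch_yaml_config.py --dpo just rewrote and confirm the DPO fields.
with open("llm_on_ray/finetune/finetune.yaml") as f:
    conf = yaml.safe_load(f)

assert conf["Dataset"]["train_file"] == "examples/data/sample_dpo_data.jsonl"
assert conf["Training"]["beta"] == 0.1
assert conf["Training"]["finetuning_model"]["dpo"] is True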
3 changes: 3 additions & 0 deletions docs/finetune_parameters.md
@@ -38,6 +38,9 @@ The following are the parameters supported in the finetuning workflow.
| mask_input |True| mask the input part in labels |
| mask_response |True| mask the response part in labels |
| data_preprocess_type |neural_chat| The type of encoding applied to the input |
|pad_max|False|Whether to pad the data to the max length of the batch|
|max_source_length|512|Maximum source sequence length. Sequences will be right padded|
|torch_dtype|bfloat16|Override the default `torch.dtype` and load the model under this dtype|


## Training Parameters
100 changes: 100 additions & 0 deletions examples/data/sample_dpo_data.jsonl

Large diffs are not rendered by default.
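Since the sample data file is not rendered above, the record below is an assumed shape, inferred only from the fields DPOIntelOrcaProcessor.make_prompt reads (system, question, chosen, rejected), not from the actual file:

import json

# Hypothetical DPO record; the real examples/data/sample_dpo_data.jsonl is not shown in this diff.
record = {
    "system": "You are a helpful assistant.",
    "question": "Which is heavier, a kilogram of feathers or a kilogram of iron?",
    "chosen": "They weigh the same: both are one kilogram.",
    "rejected": "The iron is heavier.",
}
print(json.dumps(record))  # one JSON object per line of the .jsonl file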

37 changes: 37 additions & 0 deletions examples/finetune/mpt-7b/finetune_dpo.yaml
@@ -0,0 +1,37 @@
General:
base_model: mosaicml/mpt-7b
tokenizer_name: EleutherAI/gpt-neox-20b
gpt_base_model: false
output_dir: /tmp/llm-ray/output
save_strategy: no
config:
trust_remote_code: false
use_auth_token: null
lora_config:
task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1
enable_gradient_checkpointing: false
Dataset:
train_file: examples/data/sample_dpo_data.jsonl
validation_file: null
validation_split_percentage: 5
Training:
optimizer: adamw_torch
batch_size: 2
epochs: 1
learning_rate: 1.0e-05
lr_scheduler: linear
weight_decay: 0.0
mixed_precision: bf16
device: cpu
num_training_workers: 2
resources_per_worker:
CPU: 32
accelerate_mode: DDP
gradient_accumulation_steps: 1
logging_steps: 10
finetuning_model:
dpo: True
beta: 0.1
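Assuming the packaged entry point shown in the CI functions above, this example config would presumably be launched with:

llm_on_ray-finetune --config_file examples/finetune/mpt-7b/finetune_dpo.yaml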
116 changes: 116 additions & 0 deletions llm_on_ray/finetune/data_process.py
@@ -19,6 +19,7 @@
from itertools import chain

import torch
from typing import Dict

IGNORE_INDEX = -100

@@ -218,3 +219,118 @@ def tokenize(self, examples):
examples["labels"].append(labels)
examples["attention_mask"].append(results["attention_mask"])
return examples


class DPOIntelOrcaProcessor(DataProcessor):
def __init__(self, config, tokenizer):
self.tokenizer = tokenizer
self.end = tokenizer.eos_token
self.config = config

def make_prompt(self, examples):
prompts = {}
prompts["prompt"] = []
prompts["chosen"] = []
prompts["rejected"] = []

for rec in examples:
prompts["prompt"].append(
" ".join(
[system + question for system, question in zip(rec["system"], rec["question"])]
)
)
prompts["chosen"].append(rec["chosen"])
prompts["rejected"].append(rec["rejected"])
return prompts

"""
Copied from https://github.com/intel/intel-extension-for-transformers/blob/5ba5fa8048b63bec8a3be8a7122a3db8344ad065/
intel_extension_for_transformers/neural_chat/examples/finetuning/dpo_pipeline/dpo_clm.py#L308
"""

def tokenize(self, examples):
prompts = {p.strip() for p in examples["prompt"]}
chosens = {c.strip() for c in examples["chosen"]}
rejects = {r.strip() for r in examples["rejected"]}

examples = {
"prompt": [],
"chosen": [],
"rejected": [],
"chosen_response_only": [],
"rejected_response_only": [],
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
"prompt_input_ids": [],
"prompt_attention_mask": [],
}

for prompt, chosen, reject in zip(prompts, chosens, rejects):
prompt_tokens = self.tokenizer.tokenize(prompt, return_tensors="pt")

if len(prompt_tokens) > self.config["Dataset"]["max_source_length"]:
prompt_tokens = prompt_tokens[: self.config["Dataset"]["max_source_length"]]

prompt_ids = self.tokenizer.convert_tokens_to_ids(prompt_tokens)
prompt_mask = [1] * len(prompt_ids)

max_resp = self.config["Dataset"]["max_length"] - len(prompt_ids)
chosen_tokens = self.tokenizer.tokenize(chosen)
chosen_tokens = chosen_tokens[: max_resp - 1]
chosen_tokens.append(self.end)
chosen_ids = self.tokenizer.convert_tokens_to_ids(chosen_tokens)
chosen_mask = [1] * len(chosen_ids)

reject_tokens = self.tokenizer.tokenize(reject)
reject_tokens = reject_tokens[: max_resp - 1]
reject_tokens.append(self.end)
reject_ids = self.tokenizer.convert_tokens_to_ids(reject_tokens)
reject_mask = [1] * len(reject_ids)

chosen_input_ids = prompt_ids + chosen_ids
chosen_attention_mask = prompt_mask + chosen_mask
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids

reject_input_ids = prompt_ids + reject_ids
reject_attention_mask = prompt_mask + reject_mask
reject_labels = [IGNORE_INDEX] * len(prompt_ids) + reject_ids

# padding
input_len = len(chosen_input_ids)
if self.config["Dataset"]["pad_max"]:
pad_len = self.config["Dataset"]["max_length"] - input_len
chosen_input_ids = chosen_input_ids + [0] * pad_len
chosen_labels = chosen_labels + [-100] * pad_len
chosen_attention_mask = chosen_attention_mask + [0] * pad_len
assert len(chosen_input_ids) == self.config["Dataset"]["max_length"]

input_len = len(reject_input_ids)
if self.config["Dataset"]["pad_max"]:
pad_len = self.config["Dataset"]["max_length"] - input_len
reject_input_ids = reject_input_ids + [0] * pad_len
reject_labels = reject_labels + [-100] * pad_len
reject_attention_mask = reject_attention_mask + [0] * pad_len
assert len(reject_input_ids) == self.config["Dataset"]["max_length"]

examples["prompt"].append(prompt)
examples["chosen"].append(prompt + chosen)
examples["rejected"].append(prompt + reject)
examples["chosen_response_only"].append(chosen)
examples["rejected_response_only"].append(reject)

examples["chosen_input_ids"].append(chosen_input_ids)
examples["chosen_attention_mask"].append(chosen_attention_mask)
examples["chosen_labels"].append(chosen_labels)

examples["rejected_input_ids"].append(reject_input_ids)
examples["rejected_attention_mask"].append(reject_attention_mask)
examples["rejected_labels"].append(reject_labels)

examples["prompt_input_ids"].append(prompt_ids)
examples["prompt_attention_mask"].append(prompt_mask)

return examples
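As a standalone illustration of the processor, the sketch below mirrors the flow used by tokenize_dataset in the DPO finetuning class further down; the tokenizer name comes from the example YAML above, while max_length and the json loading call are placeholder assumptions:

import datasets
import transformers

from llm_on_ray.finetune.data_process import DPOIntelOrcaProcessor

# Only the Dataset keys referenced by the processor are set here; max_length is an assumed value.
config = {"Dataset": {"max_source_length": 512, "max_length": 1024, "pad_max": False}}
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

processor = DPOIntelOrcaProcessor(config, tokenizer)
raw = datasets.load_dataset("json", data_files="examples/data/sample_dpo_data.jsonl")["train"]
prompts = datasets.Dataset.from_dict(processor.make_prompt(raw))
tokenized = prompts.map(
    processor.tokenize,
    batched=True,
    remove_columns=list(prompts.features),
    desc="Running tokenizer on DPO samples",
)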
165 changes: 165 additions & 0 deletions llm_on_ray/finetune/dpo_finetuing.py
@@ -0,0 +1,165 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datasets
import torch
import transformers
from peft import LoraConfig
from typing import Dict

from llm_on_ray.finetune.data_process import DPOIntelOrcaProcessor
from llm_on_ray.finetune.finetuning import Finetuning


class DPOFineTuning(Finetuning):
def load_tokenizer(self, config: Dict):
if config["General"].get("tokenizer_name") is not None:
tokenizer_name = config["General"].get("tokenizer_name")
else:
tokenizer_name = config["General"]["base_model"]
load_config = config["General"].get("config", {})
tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer_name,
**load_config,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer

def tokenize_dataset(self, config: Dict, tokenizer, dataset):
processor = DPOIntelOrcaProcessor(config, tokenizer)
for key in dataset:
prompts = processor.make_prompt(dataset[key])
dataset[key] = datasets.Dataset.from_dict(prompts)

train_dataset = dataset["train"]
column_names = list(train_dataset.features)
if train_dataset is not None:
# Create train feature from dataset
train_dataset = train_dataset.map(
processor.tokenize,
batched=True,
remove_columns=column_names,
desc="Running tokenizer on train dataset",
)

eval_dataset = dataset.get("validation")

if eval_dataset is not None:
column_names = eval_dataset.column_names
eval_dataset = eval_dataset.map(
processor.tokenize,
batched=True,
remove_columns=column_names,
desc="Running tokenizer on validation dataset",
)
tokenized_datasets = {"train": train_dataset, "validation": eval_dataset}

return tokenized_datasets

def load_model(self, config: Dict):
model_name = config["General"]["base_model"]
model_dtype = self.convert_dtype(config["Training"].get("mixed_precision", "no"))
model_config = config["General"].get("config", {})
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=model_dtype, **model_config
)

egc = config["General"].get("enable_gradient_checkpointing", False)
if egc:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False

model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"]))

return model

def load_model_ref(self, config: Dict):
model_name = config["General"]["base_model"]
model_dtype = self.convert_dtype(config["Training"].get("mixed_precision", "no"))
model_config = config["General"].get("config", {})

# load reference model
model_ref = transformers.AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=model_dtype, **model_config
)

model_ref.config.use_cache = False
model_ref.to(dtype=model_dtype, device=torch.device(config["Training"]["device"]))

return model_ref

def get_trainer(self, config: Dict, model, tokenizer, tokenized_dataset, data_collator):
device = config["Training"]["device"]
lora_config = config["General"].get("lora_config", None)

if device in ["cpu", "gpu"]:
from transformers import Trainer, TrainingArguments
from trl import DPOTrainer

training_args = self.convert_to_training_args(TrainingArguments, config)

trainer = DPOTrainer(
model,
self.load_model_ref(config) if lora_config is not None else None,
args=training_args,
beta=config["Training"].get("beta"),
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"]
if tokenized_dataset.get("validation") is not None
else None,
tokenizer=tokenizer,
peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
max_length=config["Dataset"].get("max_length"),
max_prompt_length=config["Dataset"].get("max_prompt_length"),
)
elif device in ["hpu"]:
from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer
from optimum.habana.transformers import GaudiTrainingArguments
from optimum.habana import GaudiConfig

# If gaudi_config_name is provided, load gaudi_config from the Hugging Face model hub (https://huggingface.co/Habana); otherwise use the default gaudi_config
gaudi_config_name = config["General"].get("gaudi_config_name", None)
if gaudi_config_name is not None:
gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name)
else:
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

training_args = self.convert_to_training_args(GaudiTrainingArguments, config)
trainer = DPOTrainer(
model,
self.load_model_ref(config) if lora_config is not None else None,
args=training_args,
gaudi_config=gaudi_config,
beta=config["Training"].get("beta"),
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"]
if tokenized_dataset.get("validation") is not None
else None,
tokenizer=tokenizer,
peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
max_length=config["Dataset"].get("max_length"),
max_prompt_length=config["Dataset"].get("max_prompt_length"),
)

return training_args, trainer
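Tying the new pieces together, a rough single-node flow would look like the sketch below; the no-argument constructor and the load_dataset call are assumptions about the existing Finetuning base class, which is not part of this diff:

import yaml

from llm_on_ray.finetune.dpo_finetuing import DPOFineTuning

with open("examples/finetune/mpt-7b/finetune_dpo.yaml") as f:
    config = yaml.safe_load(f)

finetuner = DPOFineTuning()               # assumed: no-argument constructor
tokenizer = finetuner.load_tokenizer(config)
dataset = finetuner.load_dataset(config)  # assumed: inherited from Finetuning
tokenized = finetuner.tokenize_dataset(config, tokenizer, dataset)
model = finetuner.load_model(config)
training_args, trainer = finetuner.get_trainer(
    config, model, tokenizer, tokenized, data_collator=None
)
trainer.train()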