From 3cc597e306d6d644e830a002f189237f8ed5a298 Mon Sep 17 00:00:00 2001
From: "Wu, Xiaochang"
Date: Tue, 28 May 2024 05:47:14 +0000
Subject: [PATCH 1/3] Add callbacks to the Trainer

Signed-off-by: Wu, Xiaochang
---
 llm_on_ray/finetune/finetune.py        | 3 +++
 llm_on_ray/finetune/finetune_config.py | 6 +++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index a2d6b60f2..809c49555 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -251,6 +251,8 @@ def train_func(config: Dict[str, Any]):
         tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
     )
 
+    callbacks = config["Training"].get("callbacks", None)
+
     if device in ["cpu", "gpu"]:
         from transformers import Trainer, TrainingArguments
 
@@ -264,6 +266,7 @@ def train_func(config: Dict[str, Any]):
             else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            callbacks=callbacks,
         )
 
         common.logger.info("train start")
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index e78600a6d..f705a7e73 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -16,7 +16,7 @@
 
 from pydantic import BaseModel, validator
 from typing import Optional, List
-
+from transformers.trainer_callback import TrainerCallback
 
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"
@@ -97,6 +97,10 @@ class Training(BaseModel):
     gradient_accumulation_steps: int = 1
     logging_steps: int = 10
     deepspeed_config_file: str = ""
+    callbacks: Optional[List[TrainerCallback]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
 
     @validator("device")
     def check_device(cls, v: str):

From bb1f1a96f6c993e91f95970c52ed1a4d90f799d9 Mon Sep 17 00:00:00 2001
From: "Wu, Xiaochang"
Date: Tue, 28 May 2024 06:04:34 +0000
Subject: [PATCH 2/3] Update import statement in finetune_config.py

Signed-off-by: Wu, Xiaochang
---
 llm_on_ray/finetune/finetune_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index f705a7e73..aaaf80897 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -16,7 +16,7 @@
 
 from pydantic import BaseModel, validator
 from typing import Optional, List
-from transformers.trainer_callback import TrainerCallback
+from transformers import TrainerCallback
 
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"

From 9a745654d5f183971dcf4479898cdfd0c9f64eab Mon Sep 17 00:00:00 2001
From: "Wu, Xiaochang"
Date: Tue, 28 May 2024 06:15:39 +0000
Subject: [PATCH 3/3] Update finetune_config.py with pydantic ConfigDict

Signed-off-by: Wu, Xiaochang
---
 llm_on_ray/finetune/finetune_config.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index aaaf80897..a45890a8c 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator
 from typing import Optional, List
 from transformers import TrainerCallback
 
@@ -99,8 +99,7 @@ class Training(BaseModel):
     deepspeed_config_file: str = ""
     callbacks: Optional[List[TrainerCallback]] = None
 
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     @validator("device")
     def check_device(cls, v: str):
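
Usage note (illustrative, not part of the patches): a minimal sketch of how the new "callbacks" field added above could be populated. It assumes the Hugging Face transformers package is installed; the LossLoggerCallback class and the trimmed-down config dict are hypothetical examples, and a real finetune config would carry the other required Training fields as well.

    from transformers import TrainerCallback

    class LossLoggerCallback(TrainerCallback):
        """Example callback that prints the loss reported at each logging step."""

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs and "loss" in logs:
                print(f"step {state.global_step}: loss={logs['loss']}")

    # train_func reads config["Training"].get("callbacks", None) and forwards the
    # list to transformers.Trainer, so callback instances go under Training.callbacks.
    config = {
        "Training": {
            "device": "cpu",  # example value; other Training fields omitted here
            "callbacks": [LossLoggerCallback()],
        }
    }

Because finetune_config.Training declares callbacks as Optional[List[TrainerCallback]] with arbitrary_types_allowed enabled, pydantic accepts the callback instances instead of rejecting them as unsupported field types.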