diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index a2d6b60f2..809c49555 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -251,6 +251,8 @@ def train_func(config: Dict[str, Any]):
         tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
     )
 
+    callbacks = config["Training"].get("callbacks", None)
+
     if device in ["cpu", "gpu"]:
         from transformers import Trainer, TrainingArguments
 
@@ -264,6 +266,7 @@ def train_func(config: Dict[str, Any]):
             else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            callbacks=callbacks,
         )
 
         common.logger.info("train start")
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index e78600a6d..a45890a8c 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -14,9 +14,9 @@
 # limitations under the License.
 #
 
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator
 from typing import Optional, List
-
+from transformers import TrainerCallback
 
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"
@@ -97,6 +97,9 @@ class Training(BaseModel):
     gradient_accumulation_steps: int = 1
     logging_steps: int = 10
     deepspeed_config_file: str = ""
+    callbacks: Optional[List[TrainerCallback]] = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     @validator("device")
     def check_device(cls, v: str):
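Usage note: a minimal sketch of how a caller could supply a callback through the new `callbacks` field. It assumes the finetune config is built as a Python dict with a `Training` section, matching how `train_func` reads `config["Training"].get("callbacks", None)` above; the `PrintLossCallback` class and the other `Training` keys shown are hypothetical placeholders, not part of this patch.

```python
from transformers import TrainerCallback


class PrintLossCallback(TrainerCallback):
    # Hypothetical example callback: print whatever the Trainer logs
    # (loss, learning rate, ...) together with the current global step.
    def on_log(self, args, state, control, logs=None, **kwargs):
        print(f"step {state.global_step}: {logs}")


# Sketch of a finetune config carrying callback instances; only the new
# "callbacks" entry comes from this patch, the other keys are illustrative.
finetune_config = {
    "Training": {
        "device": "cpu",
        "logging_steps": 10,
        "callbacks": [PrintLossCallback()],
    },
}
```

Allowing arbitrary `TrainerCallback` instances in the pydantic `Training` model is what requires `ConfigDict(arbitrary_types_allowed=True)` in `finetune_config.py`.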