diff --git a/README.md b/README.md
index e23e2c5..95e6496 100644
--- a/README.md
+++ b/README.md
@@ -250,6 +250,7 @@ Below is a list of all the supported models via `BaseModel` class of `xTuring` a
 |GPT-2 | gpt2|
 |LlaMA | llama|
 |LlaMA2 | llama2|
+|Mixtral-8x22B | mixtral|
 |OPT-1.3B | opt|
 
 The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions.
diff --git a/docs/docs/overview/quickstart/test.jsx b/docs/docs/overview/quickstart/test.jsx
index c738db8..0eb7b0a 100644
--- a/docs/docs/overview/quickstart/test.jsx
+++ b/docs/docs/overview/quickstart/test.jsx
@@ -15,11 +15,12 @@ const modelList = {
   cerebras: 'Cerebras',
   distilgpt2: 'DistilGPT-2',
   galactica: 'Galactica',
-  gptj: 'GPT-J',
+  gptj: 'GPT-J',
   gpt2: 'GPT-2',
   llama: 'LLaMA',
   llama2: 'LLaMA 2',
   opt: 'OPT',
+  mixtral: 'Mixtral',
 }
 
 export default function Test(
@@ -37,7 +38,7 @@ export default function Test(
   } else {
     finalKey = `${code.model}_${code.technique}`
   }
-
+
   useEffect(() => {
     setCode({
       model: 'llama',
@@ -92,8 +93,8 @@ from xturing.models import BaseModel
 
 dataset = ${instruction}Dataset('...')
 
 # Load the model
-model = BaseModel.create('${finalKey}')`}
+model = BaseModel.create('${finalKey}')`}
     />
   )
-}
\ No newline at end of file
+}
diff --git a/docs/docs/overview/supported_models.md b/docs/docs/overview/supported_models.md
index 132bd6e..729de3e 100644
--- a/docs/docs/overview/supported_models.md
+++ b/docs/docs/overview/supported_models.md
@@ -6,7 +6,7 @@ description: Models Supported by xTuring
 
 ## Base versions
 
-| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 |
+| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 |
 | ------ | --- | :---: | :---: | :---: | :---: |
 | BLOOM 1.1B| bloom | ✅ | ✅ | ✅ | ✅ |
 | Cerebras 1.3B| cerebras | ✅ | ✅ | ✅ | ✅ |
@@ -18,6 +18,7 @@ description: Models Supported by xTuring
 | LLaMA 7B | llama | ✅ | ✅ | ✅ | ✅ |
 | LLaMA2 | llama2 | ✅ | ✅ | ✅ | ✅ |
 | OPT 1.3B | opt | ✅ | ✅ | ✅ | ✅ |
+| Mixtral-8x22B | mixtral | ✅ | ✅ | ✅ | |
 
 ### Memory-efficient versions
 > The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions.
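For reference, a minimal sketch of how the new `mixtral` key and its memory-efficient variants are selected through `BaseModel`, following the `<model_key>_<technique>` template the tables above describe. Note that this change registers no `INT4 + LoRA` (`_lora_kbit`) key for Mixtral, matching the empty last column of the table.

```python
from xturing.models import BaseModel

# Base Mixtral-8x22B model
model = BaseModel.create("mixtral")

# Memory-efficient variants registered by this change
lora_model = BaseModel.create("mixtral_lora")
int8_model = BaseModel.create("mixtral_int8")
lora_int8_model = BaseModel.create("mixtral_lora_int8")
```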
diff --git a/examples/models/mixtral/mixtral.py b/examples/models/mixtral/mixtral.py
new file mode 100644
index 0000000..15cdfe4
--- /dev/null
+++ b/examples/models/mixtral/mixtral.py
@@ -0,0 +1,14 @@
+from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+instruction_dataset = InstructionDataset("./alpaca_data")
+
+# Initialize the model
+model = BaseModel.create("mixtral")
+
+# Fine-tune the model
+model.finetune(dataset=instruction_dataset)
+
+# Once the model has been fine-tuned, you can start running inference
+output = model.generate(texts=["Why are LLM models becoming so important?"])
+print("Generated output by the model: {}".format(output))
diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml
index 3ab1ef9..274b587 100644
--- a/src/xturing/config/finetuning_config.yaml
+++ b/src/xturing/config/finetuning_config.yaml
@@ -302,6 +302,32 @@ mamba:
   learning_rate: 5e-5
   weight_decay: 0.01
 
+mixtral:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 1
+
+mixtral_lora:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 4
+
+mixtral_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
+mixtral_lora_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 opt:
   learning_rate: 5e-5
   weight_decay: 0.01
diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml
index 3e472cf..fca7dec 100644
--- a/src/xturing/config/generation_config.yaml
+++ b/src/xturing/config/generation_config.yaml
@@ -256,6 +256,30 @@ llama2_lora_kbit:
 mamba:
   do_sample: false
 
+# Contrastive search
+mixtral:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Contrastive search
+mixtral_lora:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+mixtral_int8:
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+mixtral_lora_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 opt:
   penalty_alpha: 0.6
diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py
index a97842b..e277f41 100644
--- a/src/xturing/engines/__init__.py
+++ b/src/xturing/engines/__init__.py
@@ -59,6 +59,12 @@
     LlamaLoraKbitEngine,
 )
 from xturing.engines.mamba_engine import MambaEngine
+from xturing.engines.mixtral_engine import (
+    MixtralEngine,
+    MixtralInt8Engine,
+    MixtralLoraEngine,
+    MixtralLoraInt8Engine,
+)
 from xturing.engines.opt_engine import (
     OPTEngine,
     OPTInt8Engine,
@@ -109,6 +115,10 @@
 BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
 BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
 BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
+BaseEngine.add_to_registry(MixtralEngine.config_name, MixtralEngine)
+BaseEngine.add_to_registry(MixtralInt8Engine.config_name, MixtralInt8Engine)
+BaseEngine.add_to_registry(MixtralLoraEngine.config_name, MixtralLoraEngine)
+BaseEngine.add_to_registry(MixtralLoraInt8Engine.config_name, MixtralLoraInt8Engine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
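The YAML entries above only set per-key defaults. The sketch below shows how they could be inspected and overridden before fine-tuning; it assumes the usual `finetuning_config()` and `generation_config()` accessors exposed by xTuring models, with attribute names mirroring the keys added to the two config files.

```python
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

instruction_dataset = InstructionDataset("./alpaca_data")
model = BaseModel.create("mixtral_lora")

# Defaults come from finetuning_config.yaml (mixtral_lora:
# learning_rate 1e-4, batch_size 4, num_train_epochs 3).
finetuning_config = model.finetuning_config()
finetuning_config.num_train_epochs = 1  # shorten the run for a smoke test

# Generation defaults come from generation_config.yaml; the LoRA variant
# uses contrastive search (penalty_alpha 0.6, top_k 4, max_new_tokens 256).
generation_config = model.generation_config()
generation_config.max_new_tokens = 128

model.finetune(dataset=instruction_dataset)
print(model.generate(texts=["Why are LLM models becoming so important?"]))
```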
diff --git a/src/xturing/engines/mixtral_engine.py b/src/xturing/engines/mixtral_engine.py
new file mode 100644
index 0000000..783e77e
--- /dev/null
+++ b/src/xturing/engines/mixtral_engine.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+from typing import Optional, Union
+
+from xturing.engines.causal import CausalEngine, CausalLoraEngine
+
+
+class MixtralEngine(CausalEngine):
+    config_name: str = "mixtral_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="mistral-community/Mixtral-8x22B-v0.1",
+            weights_path=weights_path,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MixtralLoraEngine(CausalLoraEngine):
+    config_name: str = "mixtral_lora_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="mistral-community/Mixtral-8x22B-v0.1",
+            weights_path=weights_path,
+            target_modules=["q_proj", "v_proj"],
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MixtralInt8Engine(CausalEngine):
+    config_name: str = "mixtral_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="mistral-community/Mixtral-8x22B-v0.1",
+            weights_path=weights_path,
+            load_8bit=True,
+            trust_remote_code=True,
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MixtralLoraInt8Engine(CausalLoraEngine):
+    config_name: str = "mixtral_lora_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="mistral-community/Mixtral-8x22B-v0.1",
+            weights_path=weights_path,
+            load_8bit=True,
+            target_modules=["q_proj", "v_proj"],
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py
index 611ef2c..70d5bdd 100644
--- a/src/xturing/models/__init__.py
+++ b/src/xturing/models/__init__.py
@@ -44,6 +44,7 @@
     Llama2LoraKbit,
 )
 from xturing.models.mamba import Mamba
+from xturing.models.mixtral import Mixtral, MixtralInt8, MixtralLora, MixtralLoraInt8
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion
 
@@ -90,6 +91,10 @@
 BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
 BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
 BaseModel.add_to_registry(Mamba.config_name, Mamba)
+BaseModel.add_to_registry(Mixtral.config_name, Mixtral)
+BaseModel.add_to_registry(MixtralInt8.config_name, MixtralInt8)
+BaseModel.add_to_registry(MixtralLora.config_name, MixtralLora)
+BaseModel.add_to_registry(MixtralLoraInt8.config_name, MixtralLoraInt8)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
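For readers unfamiliar with the engine layer: the four classes in `mixtral_engine.py` above only pass arguments through to `CausalEngine`/`CausalLoraEngine`. The following is a rough, stand-alone sketch of what that plumbing amounts to in plain Hugging Face terms, not the actual code path inside xTuring; the quantized comment is an assumption about how the `*_int8` variants load.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "mistral-community/Mixtral-8x22B-v0.1"

# Load tokenizer and model; trust_remote_code mirrors the engine arguments.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    # The *_int8 engines additionally request 8-bit loading (bitsandbytes).
)

# Mixtral ships without a pad token, so every engine reuses EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
```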
diff --git a/src/xturing/models/mixtral.py b/src/xturing/models/mixtral.py
new file mode 100644
index 0000000..d462493
--- /dev/null
+++ b/src/xturing/models/mixtral.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from xturing.engines.mixtral_engine import (
+    MixtralEngine,
+    MixtralInt8Engine,
+    MixtralLoraEngine,
+    MixtralLoraInt8Engine,
+)
+from xturing.models.causal import (
+    CausalInt8Model,
+    CausalLoraInt8Model,
+    CausalLoraModel,
+    CausalModel,
+)
+
+
+class Mixtral(CausalModel):
+    config_name: str = "mixtral"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MixtralEngine.config_name, weights_path)
+
+
+class MixtralLora(CausalLoraModel):
+    config_name: str = "mixtral_lora"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MixtralLoraEngine.config_name, weights_path)
+
+
+class MixtralInt8(CausalInt8Model):
+    config_name: str = "mixtral_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MixtralInt8Engine.config_name, weights_path)
+
+
+class MixtralLoraInt8(CausalLoraInt8Model):
+    config_name: str = "mixtral_lora_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MixtralLoraInt8Engine.config_name, weights_path)
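End to end, a registry key such as `mixtral_lora_int8` resolves to the `MixtralLoraInt8` class above, which instantiates `MixtralLoraInt8Engine` (8-bit base weights plus LoRA adapters on `q_proj`/`v_proj`). A short usage sketch follows, assuming the standard `save()` API for persisting fine-tuned weights; the output directory name is only illustrative.

```python
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("./alpaca_data")

# LoRA + INT8 variant: the most memory-friendly way to fine-tune
# Mixtral-8x22B among the keys added by this change.
model = BaseModel.create("mixtral_lora_int8")
model.finetune(dataset=dataset)

# Persist the fine-tuned weights; they can be reloaded later by passing
# weights_path to MixtralLoraInt8.
model.save("./mixtral_lora_int8_weights")
```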