diff --git a/examples/tic_tac_toe/tic-tac-toe.py b/examples/tic_tac_toe/tic-tac-toe.py
index b253aed0..4b752d18 100644
--- a/examples/tic_tac_toe/tic-tac-toe.py
+++ b/examples/tic_tac_toe/tic-tac-toe.py
@@ -3,8 +3,13 @@
 import asyncio
 import argparse

 from dotenv import load_dotenv
+from vllm.plugins import load_general_plugins
+
+load_general_plugins()

 import art
+from art.utils.deploy_model import LoRADeploymentProvider
+
 from rollout import rollout, TicTacToeScenario

@@ -46,6 +51,22 @@ async def main():
         project="tic-tac-toe",
         base_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
     )
+    # taken from https://github.com/snowflakedb/ArcticInference?tab=readme-ov-file#offline
+    model._internal_config = art.dev.InternalModelConfig(
+        engine_args=art.dev.EngineArgs(
+            quantization="fp8",
+            tensor_parallel_size=1,
+            ulysses_sequence_parallel_size=2,
+            enable_shift_parallel=True,
+            speculative_config={
+                "method": "arctic",
+                "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct",
+                "num_speculative_tokens": 3,
+                "enable_suffix_decoding": True,
+                "disable_by_batch_size": 64,
+            },
+        )
+    )

     if PULL_FROM_S3:
         print("pulling from s3")
@@ -72,7 +93,7 @@ async def main():
     if DEPLOY_MODEL:
         print("deploying")
         deployment_result = await backend._experimental_deploy(
-            deploy_to="together",
+            deploy_to=LoRADeploymentProvider.TOGETHER,
             model=model,
             step=STEP,
             verbose=True,
@@ -97,8 +118,11 @@ async def main():

         print(traj)

-    if DESTROY_AFTER_RUN:
-        await backend.down()
+    if DESTROY_AFTER_RUN and args.backend == "skypilot":
+        from art.skypilot.backend import SkyPilotBackend
+
+        if isinstance(backend, SkyPilotBackend):
+            await backend.down()

     if GENERATE_BENCHMARKS:
         gpt_4o_mini = art.Model(
diff --git a/pyproject.toml b/pyproject.toml
index 5e2222ca..80b14ebd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     "torchao>=0.9.0",
     "unsloth==2025.5.1 ; sys_platform == 'linux'",
     "unsloth-zoo==2025.5.1 ; sys_platform == 'linux'",
-    "vllm==0.8.5.post1",
+    "vllm==0.9.0.1",
     "wandb>=0.19.8",
     "weave>=0.51.51",
     "peft>=0.14.0",
@@ -27,6 +27,7 @@ dependencies = [
     "hf-xet>=1.1.0",
     "panza",
     "semver>=3.0.4",
+    "arctic_inference[vllm] @ git+https://github.com/OpenPipe/ArcticInference.git@v0.0.8 ; sys_platform == 'linux'",
 ]

 [project.optional-dependencies]
diff --git a/src/art/dev/engine.py b/src/art/dev/engine.py
index e3045dfc..a0489810 100644
--- a/src/art/dev/engine.py
+++ b/src/art/dev/engine.py
@@ -87,6 +87,7 @@ class EngineArgs(TypedDict, total=False):
     guided_decoding_backend: str
     logits_processor_pattern: str | None
     # Speculative decoding configuration.
+    speculative_config: dict[str, Any] | None
     speculative_model: str | None
     speculative_model_quantization: str | None
     speculative_draft_tensor_parallel_size: int | None
@@ -125,3 +126,8 @@ class EngineArgs(TypedDict, total=False):
     additional_config: dict[str, Any] | None

     disable_log_requests: bool
+
+    # arctic-inference
+    ulysses_sequence_parallel_size: int | None
+    enable_shift_parallel: bool | None
+    shift_parallel_threshold: int | None
diff --git a/src/art/local/vllm.py b/src/art/local/vllm.py
index adb2ec98..d32e8849 100644
--- a/src/art/local/vllm.py
+++ b/src/art/local/vllm.py
@@ -15,9 +15,12 @@
 from vllm.logger import _DATE_FORMAT, _FORMAT
 from vllm.utils import FlexibleArgumentParser
 from vllm.worker.multi_step_model_runner import MultiStepModelRunner
+from vllm.plugins import load_general_plugins

 from ..dev.openai_server import OpenAIServerConfig

+load_general_plugins()
+

 async def openai_server_task(
     engine: EngineClient,
@@ -274,7 +277,9 @@ async def get_self_lora_tokenizer_async(self, *args, **kwargs):
     vllm.transformers_utils.tokenizer_group.get_lora_tokenizer_async = (  # type: ignore
         _return_nothing
     )
-    vllm.transformers_utils.tokenizer_group.TokenizerGroup.get_lora_tokenizer_async = get_self_lora_tokenizer_async  # type: ignore
+    vllm.transformers_utils.tokenizer_group.TokenizerGroup.get_lora_tokenizer_async = (
+        get_self_lora_tokenizer_async  # type: ignore
+    )


 def patch_listen_for_disconnect() -> None:
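For context, the engine arguments wired into `art.dev.EngineArgs` above mirror the offline example that the in-patch comment links to in the ArcticInference README. Below is a minimal standalone sketch of the same configuration passed directly to vLLM's offline entry point; it assumes `arctic_inference[vllm]` is installed so its vLLM plugin is picked up, otherwise stock vLLM will not recognize the Arctic-specific arguments. The prompt and model names are illustrative only.

```python
# Minimal sketch (assumption: arctic_inference's vLLM plugin is installed).
# It passes the same engine arguments as the patch, but straight to vllm.LLM.
import vllm
from vllm.plugins import load_general_plugins

load_general_plugins()  # load registered vLLM general plugins, e.g. arctic_inference

llm = vllm.LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantization="fp8",
    tensor_parallel_size=1,
    ulysses_sequence_parallel_size=2,  # Arctic-specific engine arg
    enable_shift_parallel=True,  # Arctic-specific engine arg
    speculative_config={
        "method": "arctic",
        "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct",
        "num_speculative_tokens": 3,
        "enable_suffix_decoding": True,
        "disable_by_batch_size": 64,
    },
)

# Quick smoke test of the speculative-decoding setup.
print(llm.generate("Place X in the center square.")[0].outputs[0].text)
```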