
[WIP] Arctic inference #159


Open · wants to merge 2 commits into main
30 changes: 27 additions & 3 deletions examples/tic_tac_toe/tic-tac-toe.py
@@ -3,8 +3,13 @@
 import asyncio
 import argparse
 from dotenv import load_dotenv
+from vllm.plugins import load_general_plugins
+
+load_general_plugins()
+
 import art
+from art.utils.deploy_model import LoRADeploymentProvider

 from rollout import rollout, TicTacToeScenario


@@ -46,6 +51,22 @@ async def main():
         project="tic-tac-toe",
         base_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
     )
+    # taken from https://github.com/snowflakedb/ArcticInference?tab=readme-ov-file#offline
+    model._internal_config = art.dev.InternalModelConfig(
+        engine_args=art.dev.EngineArgs(
+            quantization="fp8",
+            tensor_parallel_size=1,
+            ulysses_sequence_parallel_size=2,
+            enable_shift_parallel=True,
+            speculative_config={
+                "method": "arctic",
+                "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct",
+                "num_speculative_tokens": 3,
+                "enable_suffix_decoding": True,
+                "disable_by_batch_size": 64,
+            },
+        )
+    )

     if PULL_FROM_S3:
         print("pulling from s3")
@@ -72,7 +93,7 @@ async def main():
     if DEPLOY_MODEL:
         print("deploying")
         deployment_result = await backend._experimental_deploy(
-            deploy_to="together",
+            deploy_to=LoRADeploymentProvider.TOGETHER,
             model=model,
             step=STEP,
             verbose=True,
@@ -97,8 +118,11 @@ async def main():

     print(traj)

-    if DESTROY_AFTER_RUN:
-        await backend.down()
+    if DESTROY_AFTER_RUN and args.backend == "skypilot":
+        from art.skypilot.backend import SkyPilotBackend
+
+        if isinstance(backend, SkyPilotBackend):
+            await backend.down()

     if GENERATE_BENCHMARKS:
         gpt_4o_mini = art.Model(
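For reference, the engine arguments above mirror the offline example in the ArcticInference README that the comment links to. Below is a minimal standalone sketch of the same settings passed directly to vLLM, assuming arctic_inference[vllm] is installed so the extra Ulysses/shift-parallel arguments are registered when vLLM loads its plugins; the prompt and sampling settings are illustrative only.

# Sketch only: the ulysses/shift-parallel arguments exist only after the
# arctic_inference plugin has been loaded by vLLM.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantization="fp8",
    tensor_parallel_size=1,
    ulysses_sequence_parallel_size=2,
    enable_shift_parallel=True,
    speculative_config={
        "method": "arctic",
        "model": "Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct",
        "num_speculative_tokens": 3,
        "enable_suffix_decoding": True,
        "disable_by_batch_size": 64,
    },
)
outputs = llm.generate(["Let's play tic-tac-toe."], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)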
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
"torchao>=0.9.0",
"unsloth==2025.5.1 ; sys_platform == 'linux'",
"unsloth-zoo==2025.5.1 ; sys_platform == 'linux'",
"vllm==0.8.5.post1",
"vllm==0.9.0.1",
"wandb>=0.19.8",
"weave>=0.51.51",
"peft>=0.14.0",
@@ -27,6 +27,7 @@ dependencies = [
"hf-xet>=1.1.0",
"panza",
"semver>=3.0.4",
"arctic_inference[vllm] @ git+https://github.com/OpenPipe/[email protected] ; sys_platform == 'linux'",
]

[project.optional-dependencies]
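To confirm the pins resolved in a given environment, a quick check of the installed distributions may help; the arctic-inference distribution name is an assumption and may differ from the package's import name.

# Sanity-check sketch: report installed versions of the pinned packages.
from importlib.metadata import version, PackageNotFoundError

for dist in ("vllm", "arctic-inference"):
    try:
        print(dist, version(dist))  # vllm should report 0.9.0.1 with the pin above
    except PackageNotFoundError:
        print(dist, "not installed (expected outside Linux)")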
6 changes: 6 additions & 0 deletions src/art/dev/engine.py
@@ -87,6 +87,7 @@ class EngineArgs(TypedDict, total=False):
     guided_decoding_backend: str
     logits_processor_pattern: str | None
     # Speculative decoding configuration.
+    speculative_config: dict[str, Any] | None
     speculative_model: str | None
     speculative_model_quantization: str | None
     speculative_draft_tensor_parallel_size: int | None
@@ -125,3 +126,8 @@ class EngineArgs(TypedDict, total=False):
     additional_config: dict[str, Any] | None

     disable_log_requests: bool
+
+    # arctic-inference
+    ulysses_sequence_parallel_size: int | None
+    enable_shift_parallel: bool | None
+    shift_parallel_threshold: int | None
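Because EngineArgs is a TypedDict with total=False, the new Arctic-related keys stay optional, so existing configs that omit them continue to type-check. A minimal illustrative construction (values are examples only, not defaults from this PR):

# Illustrative only: every EngineArgs field is optional, so only the
# arctic-inference settings of interest need to be supplied.
args = EngineArgs(
    ulysses_sequence_parallel_size=2,
    enable_shift_parallel=True,
    speculative_config={"method": "arctic", "num_speculative_tokens": 3},
)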
7 changes: 6 additions & 1 deletion src/art/local/vllm.py
@@ -15,9 +15,12 @@
 from vllm.logger import _DATE_FORMAT, _FORMAT
 from vllm.utils import FlexibleArgumentParser
 from vllm.worker.multi_step_model_runner import MultiStepModelRunner
+from vllm.plugins import load_general_plugins

 from ..dev.openai_server import OpenAIServerConfig

+load_general_plugins()
+

 async def openai_server_task(
     engine: EngineClient,
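The early load_general_plugins() call asks vLLM to import any plugins registered under its general-plugins entry point (arctic_inference registers itself this way) before the server and its arguments are set up. A small sketch for checking that the plugin is visible to that mechanism; the entry-point group name comes from vLLM's plugin loader and should be treated as an assumption here.

# Sketch: list installed vLLM general plugins; arctic_inference should appear
# once the dependency from pyproject.toml is installed (group name assumed).
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.general_plugins"):
    print(ep.name, "->", ep.value)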
@@ -274,7 +277,9 @@ async def get_self_lora_tokenizer_async(self, *args, **kwargs):
     vllm.transformers_utils.tokenizer_group.get_lora_tokenizer_async = (  # type: ignore
         _return_nothing
     )
-    vllm.transformers_utils.tokenizer_group.TokenizerGroup.get_lora_tokenizer_async = get_self_lora_tokenizer_async  # type: ignore
+    vllm.transformers_utils.tokenizer_group.TokenizerGroup.get_lora_tokenizer_async = (
+        get_self_lora_tokenizer_async  # type: ignore
+    )


 def patch_listen_for_disconnect() -> None: