
Commit 8cac35b

[Ray] Improve documentation on batch inference (#16609)
Signed-off-by: Richard Liaw <[email protected]>
1 parent 9dbf7a2 commit 8cac35b

2 files changed: +90 -109 lines changed

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data-parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading from and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline-parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (in formats such as JSONL, Parquet, CSV, and binary).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure the vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)
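
# The docstring above notes that this integration is also compatible with
# tensor/pipeline-parallel inference. A minimal sketch of such a configuration
# (illustrative only, not part of the original example; "tensor_parallel_size"
# is a standard vLLM engine argument, and the right values depend on your
# model and cluster):
#
# tp_config = vLLMEngineProcessorConfig(
#     model_source="unsloth/Llama-3.1-8B-Instruct",
#     engine_kwargs={
#         "tensor_parallel_size": 2,  # shard each replica across 2 GPUs
#         "max_model_len": 16384,
#     },
#     concurrency=1,  # number of parallel vLLM replicas (each using 2 GPUs)
#     batch_size=64,
# )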

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{
            "role": "system",
            "content": "You are a bot that responds with haikus."
        }, {
            "role": "user",
            "content": row["text"]
        }],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row  # This will return all the original columns in the dataset.
    ),
)

ds = vllm_processor(ds)
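
# Note: Ray Data executes lazily. The call above only builds the pipeline;
# inference actually runs once the dataset is consumed, e.g. by take() or
# write_parquet() below.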

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
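#
# As an illustrative follow-up (not part of the original example), the Parquet
# output could be read back with Ray Data for downstream processing:
#
# results = ray.data.read_parquet("s3://<your-output-bucket>")
# results.show(limit=5)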

examples/offline_inference/distributed.py

Lines changed: 0 additions & 109 deletions
This file was deleted.
