# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data-parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout.
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False
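
# To run against an existing Ray cluster instead of starting a new local one,
# point ray.init at it. A sketch; the exact address depends on your setup:
# ray.init(address="auto")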

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")
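
# If you do not have S3 access handy, a small in-memory dataset can stand in
# for local testing. A sketch; the prompts below are illustrative placeholders,
# not part of the original dataset:
#
# ds = ray.data.from_items(
#     [
#         {"text": "Write a haiku about the ocean."},
#         {"text": "Write a haiku about mountains."},
#     ]
# )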

# Configure the vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # Number of parallel vLLM engine replicas.
    batch_size=64,
)
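
# The same config scales to multi-GPU replicas, since engine_kwargs are passed
# through to vLLM. A sketch, assuming each replica has 2 GPUs available
# (adjust tensor_parallel_size to your hardware):
#
# config = vLLMEngineProcessorConfig(
#     model_source="unsloth/Llama-3.1-8B-Instruct",
#     engine_kwargs={
#         "tensor_parallel_size": 2,  # Shard the model across 2 GPUs.
#         "max_model_len": 16384,
#     },
#     concurrency=1,
#     batch_size=64,
# )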

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[
            {
                "role": "system",
                "content": "You are a bot that responds with haikus.",
            },
            {
                "role": "user",
                "content": row["text"],
            },
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        ),
    ),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row  # This will return all the original columns in the dataset.
    ),
)

# Apply the processor to the dataset. Ray Data executes lazily, so inference
# runs when the results are consumed below.
ds = vllm_processor(ds)

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write the inference results out as Parquet files to S3.
# Multiple files will be written to the output destination,
# and each task will write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
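
# The written results can be read back with Ray Data for later inspection.
# A sketch, reusing the placeholder bucket from above:
#
# results = ray.data.read_parquet("s3://<your-output-bucket>")
# results.show(limit=5)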