Commit c2cf4a0

Author: maxtext authors
Merge pull request #1842 from AI-Hypercomputer:sharding_viz
PiperOrigin-RevId: 772704594
2 parents: f827a03 + 0c94f82

File tree: 4 files changed, +1556 -0 lines
Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CLI utility for running inference with interleaved prefill and generate."""

import os
import uuid
from typing import Sequence, List

from absl import app

import jax

from MaxText import max_utils, maxengine, pyconfig

_NUM_STREAMS = 5
# How many streams to prefill initially before starting generation.
_INITIAL_PREFILL_STREAMS = 2  # Example: Start generating after 2 streams are ready


def _validate_config(config):
  """Validate configuration settings."""
  assert config.load_full_state_path == "", (
      "Decode doesn't operate on full states! Convert to a parameter checkpoint first using generate_param_only_checkpoint."
  )
  assert (
      0 < _INITIAL_PREFILL_STREAMS <= _NUM_STREAMS
  ), f"_INITIAL_PREFILL_STREAMS ({_INITIAL_PREFILL_STREAMS}) must be > 0 and <= _NUM_STREAMS ({_NUM_STREAMS})"


def main(argv: Sequence[str]) -> None:
  """Main function to run interleaved inference."""
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  config = pyconfig.initialize(argv)
  _validate_config(config)
  max_utils.print_system_information()

  engine = maxengine.MaxEngine(config)
  rng = jax.random.PRNGKey(1234)
  rng, rng_load_params = jax.random.split(rng)
  params = engine.load_params(rng=rng_load_params)

  text = config.prompt
  metadata = engine.get_tokenizer()
  tokenizer_model = engine.build_tokenizer(metadata)
  tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
  assert true_length <= config.max_prefill_predict_length, "Prompt too long for prefill length"

  batch_size = int(config.per_device_batch_size * jax.device_count())
  assert 0 < _NUM_STREAMS <= batch_size, f"The number of streams {_NUM_STREAMS} must be > 0 and <= batch size {batch_size}"

  # Initialize decode state
  rng, rng_init_decode = jax.random.split(rng)
  decode_state = engine.init_decode_state(rng=rng_init_decode)
  print("Initial decode state initialized.")

  # Keep track of results per stream (slot)
  streams_results: dict[int, List[int]] = {i: [] for i in range(_NUM_STREAMS)}
  streams_active: List[bool] = [False] * _NUM_STREAMS  # Track which slots are active
  streams_finished: List[bool] = [False] * _NUM_STREAMS  # Track finished streams
  streams_prefilled_count = 0
  streams_inserted_count = 0

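  # Only _INITIAL_PREFILL_STREAMS requests are prefilled up front; the remaining
  # streams are prefilled and inserted one per generation step, whenever the
  # decode batch still has a free slot.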
  # --- Initial Prefill Phase ---
  print(f"Starting initial prefill for {_INITIAL_PREFILL_STREAMS} streams...")
  prefill_results_to_insert = {}  # Store prefill results before inserting
  for i in range(_INITIAL_PREFILL_STREAMS):
    slot_idx = i
    print(f" Prefilling stream for slot {slot_idx}...")
    rng, rng_prefill = jax.random.split(rng)
    request_id = uuid.uuid4()
    prefill_result, first_token = engine.prefill(
        params=params,
        padded_tokens=tokens,
        true_length=true_length,
        rng=rng_prefill,
        slot=slot_idx,
        request_id=request_id,
    )
    prefill_results_to_insert[slot_idx] = prefill_result
    streams_results[slot_idx].append(first_token.get_result_at_slot(0).tokens.item())
    streams_prefilled_count += 1
    print(f"After prefill stream {slot_idx}")

  # --- Insert Initial Prefills ---
  print("Inserting initial prefill results...")
  for slot_idx, prefill_result in prefill_results_to_insert.items():
    request_id = uuid.uuid4()
    decode_state = engine.insert(
        prefix=prefill_result,
        decode_state=decode_state,
        slot=slot_idx,
        request_id=request_id,  # Pass request_id
    )
    streams_active[slot_idx] = True  # Mark stream as active
    streams_inserted_count += 1
    print(f" Inserted prefill for slot {slot_idx}")

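  # Each iteration of the loop below (1) runs one generate step for every active
  # slot, (2) retires slots that reach max_target_length and releases their pages,
  # and (3) prefills and inserts at most one new stream into a free slot while
  # un-prefilled streams remain.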
  print("Starting interleaved generation loop...")
  total_steps = config.max_target_length - config.max_prefill_predict_length
  for step in range(total_steps):
    print(f"\n--- Step {step + 1} / {total_steps} ---")

    # Generate step for all active streams
    active_stream_indices = [i for i, active in enumerate(streams_active) if active and not streams_finished[i]]
    if active_stream_indices:
      print(f" Generating for active slots: {active_stream_indices}")
      rng, rng_generate = jax.random.split(rng)
      decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)

      # Store the generated token and check for finished streams
      for slot_idx in active_stream_indices:
        # Check if the stream finished this step
        current_len = config.max_prefill_predict_length + step + 1  # Includes prefill + current step
        finished_this_step = False
        if current_len >= config.max_target_length:
          print(f" Stream in slot {slot_idx} reached max target length.")
          streams_finished[slot_idx] = True
          streams_active[slot_idx] = False
          finished_this_step = True

        # Store token if it wasn't already finished before this step or if it finished on this step
        if not streams_finished[slot_idx] or finished_this_step:
          # Ensure we don't try to access results for a slot that might not exist
          if slot_idx < sampled_tokens.data.shape[0]:
            token_for_slot = sampled_tokens.get_result_at_slot(slot_idx).tokens.item()
            streams_results[slot_idx].append(token_for_slot)
          else:
            print(f"Warning: Tried to get token for slot {slot_idx}, but batch size seems smaller.")

        # Call release_pages if finished this step
        if finished_this_step:
          print(f" Calling engine to release pages for finished slot {slot_idx}...")
          engine.release_pages(slot=slot_idx)

    else:
      print(" No active streams to generate for.")

    # 2. Check if all streams are finished (can exit loop early)
    if all(streams_finished):
      print("\nAll streams finished generation.")
      break

    # 3. Prefill and Insert new streams if capacity allows
    num_active_not_finished = sum(1 for i in range(_NUM_STREAMS) if streams_active[i] and not streams_finished[i])
    available_slots = batch_size - num_active_not_finished
    can_prefill_more = streams_prefilled_count < _NUM_STREAMS

    if can_prefill_more and available_slots > 0:
      try:
        next_available_slot = streams_active.index(False)
        print(f" Prefilling new stream for slot {next_available_slot}...")
        rng, rng_prefill = jax.random.split(rng)
        request_id = uuid.uuid4()
        prefill_result, first_token = engine.prefill(
            params=params,
            padded_tokens=tokens,
            true_length=true_length,
            rng=rng_prefill,
            slot=next_available_slot,
            request_id=request_id,
        )
        streams_prefilled_count += 1

        # Insert the new prefill
        print(f" Inserting new stream into slot {next_available_slot}...")
        request_id_insert = uuid.uuid4()
        decode_state = engine.insert(
            prefix=prefill_result,
            decode_state=decode_state,
            slot=next_available_slot,
            request_id=request_id_insert,
        )
        streams_active[next_available_slot] = True
        streams_inserted_count += 1
        streams_results[next_available_slot].append(first_token.get_result_at_slot(0).tokens.item())

      except ValueError:
        print(" Warning: Available slots detected but couldn't find an inactive one.")
    elif can_prefill_more:
      print(" Generate step finished, but no available slots to prefill new stream.")
    else:
      print(" Generate step finished, all streams already prefilled.")

  print("\n--- Final Results ---")
  for i in range(_NUM_STREAMS):
    if streams_results[i]:
      output = tokenizer_model.decode(streams_results[i])
      print(f"Stream {i}: Input=`{text}` -> Output=`{output}`")

      if i == 0:  # Check first stream as an example
        assert output.startswith(
            config.autoregressive_decode_assert
        ), f"Stream {i} generated text mismatch: `{output}` vs expected start `{config.autoregressive_decode_assert}`"
    else:
      print(f"Stream {i}: Was not activated.")


if __name__ == "__main__":
  app.run(main)
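
For reference, MaxText CLIs of this shape take a base config file followed by key=value overrides, which pyconfig.initialize parses directly from argv. A minimal launch sketch; the module path and override values below are illustrative assumptions, not taken from this commit:

python3 -m MaxText.<new_interleaved_inference_module> MaxText/configs/base.yml \
  prompt="I love to" per_device_batch_size=4 \
  max_prefill_predict_length=64 max_target_length=128

The overrides shown are among the config keys the script reads; a parameter-only checkpoint must also be supplied, since _validate_config rejects any load_full_state_path.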

MaxText/inference/scripts/notebooks/sharding_utils.ipynb

Lines changed: 411 additions & 0 deletions
Large diffs are not rendered by default.
