diff --git a/examples/llama/pippy_llama.py b/examples/llama/pippy_llama.py
index 168d47045..e06a30833 100644
--- a/examples/llama/pippy_llama.py
+++ b/examples/llama/pippy_llama.py
@@ -33,7 +33,7 @@
 # Create a pipeline representation from the model
 mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device)
-pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],))
+pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],), split_spec=split_spec)
 
 # Create pipeline stage for each rank
 stage = pipe.build_stage(rank, device=device)