59 changes: 35 additions & 24 deletions cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h
@@ -36,18 +36,28 @@ class DecoderInputBuffers
using SizeType32 = runtime::SizeType32;
using TensorPtr = runtime::ITensor::SharedPtr;

explicit DecoderInputBuffers(
SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager);
explicit DecoderInputBuffers(SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps,
runtime::BufferManager const& manager);

// buffers for setup
//! Buffers for decoder setup

//! Input IDs of new requests, [maxBatchSize]
TensorPtr inputsIds;
//! Batch slots for setup step, [maxBatchSize]
TensorPtr setupBatchSlots;
TensorPtr setupBatchSlotsDevice;
//! Helper buffer for copying sequence lengths, [maxBatchSize]
TensorPtr fillValues;
TensorPtr fillValuesDevice;

// buffers for forward
//! Buffers for decoder forward

//! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize]
std::vector<TensorPtr> forwardBatchSlots;

//! Logits for all batch slots, [maxNumSequences]
//! The vector is sparse, only slots in forwardBatchSlots are used.
std::vector<TensorPtr> logits;
};

class DecoderOutputBuffers
@@ -70,35 +80,36 @@ class DecoderOutputBuffers
TensorPtr finishReasonsHost; // [mMaxNumRequests, beamWidth], pinned host tensor
};

class DecoderBuffers
class DraftBuffers
{
public:
using SizeType32 = runtime::SizeType32;
using TensorPtr = runtime::ITensor::SharedPtr;

std::vector<TensorPtr> logits;
TensorPtr nextDraftTokensDevice; // [mMaxNumRequests, maxTokensPerStep-1]
TensorPtr nextDraftTokensHost; // [mMaxNumRequests, maxTokensPerStep-1]
TensorPtr prevDraftTokensLengthsDevice; // [mMaxNumRequests]
TensorPtr prevDraftTokensLengthsHost; // [mMaxNumRequests]
TensorPtr nextDraftTokensLengthsDevice; // [mMaxNumRequests]
TensorPtr nextDraftTokensLengthsHost; // [mMaxNumRequests]
TensorPtr acceptedLengthsCumSumDevice; // [mMaxNumRequests+1]
TensorPtr acceptedPackedPathsDevice; // [mMaxNumRequests * maxAcceptedTokens]
std::vector<std::vector<runtime::ITensor::SharedPtr>>
predictedDraftLogits; // [mMaxNumRequests][mMaxNumHeads][maxDraftTokens + 1, vocabSize]

void create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
runtime::ModelConfig const& modelConfig);
};

class DecoderBuffers
{
public:
using SizeType32 = runtime::SizeType32;
using TensorPtr = runtime::ITensor::SharedPtr;

TensorPtr cacheIndirectionInput;
TensorPtr cacheIndirectionOutput;

class DraftBuffers
{
public:
TensorPtr nextDraftTokensDevice; // [mMaxNumRequests, maxTokensPerStep-1]
TensorPtr nextDraftTokensHost; // [mMaxNumRequests, maxTokensPerStep-1]
TensorPtr prevDraftTokensLengthsDevice; // [mMaxNumRequests]
TensorPtr prevDraftTokensLengthsHost; // [mMaxNumRequests]
TensorPtr nextDraftTokensLengthsDevice; // [mMaxNumRequests]
TensorPtr nextDraftTokensLengthsHost; // [mMaxNumRequests]
TensorPtr acceptedLengthsCumSumDevice; // [mMaxNumRequests+1]
TensorPtr acceptedPackedPathsDevice; // [mMaxNumRequests * maxAcceptedTokens]
std::vector<std::vector<runtime::ITensor::SharedPtr>>
predictedDraftLogits; // [mMaxNumRequests][mMaxNumHeads][maxDraftTokens + 1, vocabSize]

void create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& manager,
runtime::ModelConfig const& modelConfig);
};

DraftBuffers draftBuffers;

DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
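For orientation, the net effect of the decoderBuffers.h changes above: the per-slot logits vector moves onto DecoderInputBuffers and is sized by the new maxNumSequences argument, and DraftBuffers is promoted from a nested type of DecoderBuffers to a top-level class. The following is a minimal construction sketch, not code from this PR; the sizes are illustrative, and manager and modelConfig stand for the BufferManager and ModelConfig the caller already holds (the real call site is in trtGptModelInflightBatching.cpp further down).

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;

// Sketch only: placeholder sizes, not values from this PR.
tr::SizeType32 const maxNumSequences = 64;
tr::SizeType32 const maxBatchSize = 8;
tr::SizeType32 const maxDecoderSteps = 4;

// New constructor: maxNumSequences leads and sizes the sparse per-slot logits vector.
tb::DecoderInputBuffers inputBuffers(maxNumSequences, maxBatchSize, maxDecoderSteps, manager);

// inputBuffers.logits is indexed by sequence slot; only the slots listed in
// inputBuffers.forwardBatchSlots hold live tensors during a decoder step.

// DraftBuffers is now a top-level class rather than DecoderBuffers::DraftBuffers;
// it is still created via create(), with maxDecoderSteps standing in here for maxTokensPerStep.
tb::DraftBuffers draftBuffers;
draftBuffers.create(maxNumSequences, maxDecoderSteps, manager, modelConfig);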
12 changes: 6 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h
@@ -31,8 +31,8 @@ class CudaStream;
namespace tensorrt_llm::batch_manager
{

class RuntimeBuffers;
class DecoderBuffers;
class DecoderInputBuffers;
class DraftBuffers;
class MedusaBuffers;

namespace tr = tensorrt_llm::runtime;
@@ -47,10 +47,10 @@ class HandleContextLogits : Algorithm

HandleContextLogits() = default;

tr::SizeType32 operator()(RequestVector const& contextRequests,
std::vector<tr::SizeType32> const& numContextLogitsVec, tr::ITensor::SharedPtr const& logits,
DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
tensorrt_llm::runtime::CudaStream const& stream, OptionalRef<MedusaBuffers> medusaBuffers) const;
tr::SizeType32 operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
tr::ITensor::SharedPtr const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef<DraftBuffers> draftBuffers,
OptionalRef<MedusaBuffers> medusaBuffers) const;
};

} // namespace tensorrt_llm::batch_manager
10 changes: 6 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h
@@ -30,8 +30,9 @@ class BufferManager;
namespace tensorrt_llm::batch_manager
{

class DecoderInputBuffers;
class DraftBuffers;
class RuntimeBuffers;
class DecoderBuffers;

namespace tr = tensorrt_llm::runtime;

@@ -45,9 +46,10 @@ class HandleGenerationLogits : Algorithm

HandleGenerationLogits() = default;

void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers,
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, tr::ITensor::SharedPtr const& logits,
OptionalRef<RuntimeBuffers> genRuntimeBuffers) const;
void operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests,
tr::ITensor::SharedPtr const& logits, tr::SizeType32 logitsIndex, tr::ModelConfig const& modelConfig,
tr::BufferManager const& manager, OptionalRef<RuntimeBuffers> genRuntimeBuffers,
OptionalRef<DraftBuffers> draftBuffers) const;
};

} // namespace tensorrt_llm::batch_manager
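The two handlers are designed to chain: HandleContextLogits consumes the leading rows of the fused logits tensor and returns the offset at which the generation logits begin, and HandleGenerationLogits continues from that offset, with both writing into the slot-indexed inputBuffers.logits. Below is a hedged sketch of the chained calls in the fused context+generation path; the variable names are placeholders for state the caller holds, and the real wiring is in decoderStepAsync further down in this diff.

namespace tb = tensorrt_llm::batch_manager;

// Sketch only: inputBuffers, the request vectors, fusedLogits, modelConfig, manager and
// the OptionalRef arguments stand for what the caller (e.g. decoderStepAsync) already holds.
tb::HandleContextLogits const handleContextLogits{};
tb::HandleGenerationLogits const handleGenerationLogits{};

// Context phase: returns the offset at which the generation logits start.
auto const logitsIndex = handleContextLogits(inputBuffers, contextRequests, fusedLogits,
    numContextLogitsVec, modelConfig, manager, draftBuffers, medusaBuffers);

// Generation phase: starts reading at that offset and fills the remaining sequence slots.
handleGenerationLogits(inputBuffers, generationRequests, fusedLogits, logitsIndex,
    modelConfig, manager, genRuntimeBuffers, draftBuffers);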
8 changes: 2 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h
@@ -30,10 +30,6 @@ class TllmRuntime;
namespace tensorrt_llm::batch_manager
{

class DecoderBuffers;

namespace tr = tensorrt_llm::runtime;

class LogitsPostProcessor : Algorithm
{
public:
@@ -48,8 +44,8 @@ class LogitsPostProcessor : Algorithm
LogitsPostProcessor() = default;

bool operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool replicateLogitsPostProcessor, DecoderBuffers& decoderBuffers, tr::WorldConfig const& worldConfig,
tr::TllmRuntime& runtime,
bool replicateLogitsPostProcessor, std::vector<batch_manager::LlmRequest::TensorPtr>& seqSlotLogits,
runtime::WorldConfig const& worldConfig, runtime::TllmRuntime& runtime,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
};

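With this change the post-processor no longer depends on DecoderBuffers at all; it only needs the slot-indexed logits, which callers can now pass straight from DecoderInputBuffers. A hedged usage sketch follows; the request vectors, worldConfig, runtime and inputBuffers are placeholders for caller-held state, and the real call site is in decoderStepAsync further down.

namespace tb = tensorrt_llm::batch_manager;

// Sketch only: pass the per-slot logits vector (e.g. inputBuffers.logits) directly.
tb::LogitsPostProcessor const logitsPostProcessor{};

bool const applied = logitsPostProcessor(contextRequests, generationRequests,
    /*replicateLogitsPostProcessor=*/true, inputBuffers.logits, worldConfig, runtime,
    /*logitsPostProcessorBatched=*/std::nullopt);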
13 changes: 5 additions & 8 deletions cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
@@ -30,7 +30,7 @@ namespace tensorrt_llm::batch_manager
{

DecoderInputBuffers::DecoderInputBuffers(
SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
{
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
auto const nvSizeType = TRTDataType<SizeType32>::value;
@@ -48,6 +48,8 @@ DecoderInputBuffers::DecoderInputBuffers(
{
forwardBatchSlots.emplace_back(BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize}), nvSizeType));
}

logits.resize(maxNumSequences);
}

DecoderOutputBuffers::DecoderOutputBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxSeqLen,
@@ -91,11 +93,6 @@ DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWid
SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig,
WorldConfig const& worldConfig)
{
if (worldConfig.isLastPipelineParallelRank())
{
logits.resize(maxNumSequences);
}

cacheIndirectionInput = manager.gpu(
ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow}), nvinfer1::DataType::kINT32);
cacheIndirectionOutput = manager.gpu(
@@ -109,8 +106,8 @@ DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWid
}
}

void DecoderBuffers::DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep,
BufferManager const& manager, ModelConfig const& modelConfig)
void DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, BufferManager const& manager,
ModelConfig const& modelConfig)
{
auto const speculativeDecodingMode = modelConfig.getSpeculativeDecodingMode();

13 changes: 7 additions & 6 deletions cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
@@ -67,9 +67,9 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons

} // namespace

SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
std::vector<SizeType32> const& numContextLogitsVec, TensorPtr const& logits, DecoderBuffers& decoderBuffers,
tr::ModelConfig const& modelConfig, BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream,
SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
tr::ITensor::SharedPtr const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef<DraftBuffers> draftBuffers,
OptionalRef<MedusaBuffers> medusaBuffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -114,13 +114,14 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
// Get the logits from the last context token and draft tokens
auto const numDecoderLogits = 1 + draftLength;
auto const seqSlot = llmReq->mSeqSlot.value();
auto& decoderLogits = decoderBuffers.logits.at(seqSlot);
auto& decoderLogits = inputBuffers.logits.at(seqSlot);
TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits);

if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
{
TLLM_CHECK(draftBuffers);
auto& medusaLogitsHeads = draftBuffers->predictedDraftLogits.at(seqSlot);
TLLM_CHECK(medusaBuffers);
auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice,
modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex - numDecoderLogits,
numDecoderLogits);
@@ -143,7 +144,7 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
auto const logitsShape = logitsView->getShape();
auto const logitsType = logitsView->getDataType();
decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType);
tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, stream);
tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, manager.getStream());
decoderLogits->unsqueeze(0);
}
else
13 changes: 8 additions & 5 deletions cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
@@ -73,9 +73,10 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons

} // namespace

void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector const& generationRequests,
DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager,
TensorPtr const& logits, OptionalRef<RuntimeBuffers> genRuntimeBuffers) const
void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests,
tr::ITensor::SharedPtr const& logits, tr::SizeType32 logitsIndex, tr::ModelConfig const& modelConfig,
tr::BufferManager const& manager, OptionalRef<RuntimeBuffers> genRuntimeBuffers,
OptionalRef<DraftBuffers> draftBuffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(HandleGenerationLogits);
@@ -99,7 +100,7 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
TensorPtr logitsView = ITensor::slice(logits, logitsIndex, numLogits);
TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid<float>(*logitsView, manager, "logits") == false,
"Found invalid number (NaN or Inf) in logits");
auto& decoderLogits = decoderBuffers.logits.at(seqSlot);
auto& decoderLogits = inputBuffers.logits.at(seqSlot);
auto const logitsViewShape = logitsView->getShape();
if (reqBeamWidth > 1)
{
@@ -136,8 +137,10 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
}
if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
{
TLLM_CHECK(draftBuffers);
auto& medusaLogitsHeads = draftBuffers->predictedDraftLogits.at(seqSlot);
TLLM_CHECK(genRuntimeBuffers);
auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
TLLM_CHECK(genRuntimeBuffers->mMedusaBuffers);
setupMedusaLogits(medusaLogitsHeads, genRuntimeBuffers->mMedusaBuffers->medusaLogitsDevice,
modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex, draftLength);
}
8 changes: 5 additions & 3 deletions cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp
@@ -25,6 +25,8 @@
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"

namespace tr = tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager
{

@@ -34,7 +36,7 @@ using ITensor = runtime::ITensor;
using SizeType32 = tensorrt_llm::runtime::SizeType32;

bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool replicateLogitsPostProcessor, DecoderBuffers& decoderBuffers, tr::WorldConfig const& worldConfig,
bool replicateLogitsPostProcessor, std::vector<TensorPtr>& seqSlotLogits, tr::WorldConfig const& worldConfig,
tr::TllmRuntime& runtime, std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -59,7 +61,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
logitsPostProcessorIsApplied = true;
if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
{
auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value());
auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
(*llmReq->mLogitsPostProcessor)(
llmReq->mRequestId, logits, llmReq->getTokens(), runtime.getStreamPtr(), llmReq->mClientId);
}
@@ -68,7 +70,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
{
reqIdsVec.push_back(llmReq->mRequestId);

auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value());
auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
logitsVec.push_back(logits);

beamTokensVec.emplace_back(llmReq->getTokens());
cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
@@ -132,7 +132,7 @@ MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, R

auto [activeSlots, generationSteps] = getActiveSlots(contextRequests, generationRequests);

auto decodingInput = createDecoderBatchInputs(activeSlots, decoderState, decoderBuffers.logits, maxNumSequences,
auto decodingInput = createDecoderBatchInputs(activeSlots, decoderState, inputBuffers.logits, maxNumSequences,
inputBuffers.forwardBatchSlots, decoderBuffers.cacheIndirectionInput);
decodingInput->generationSteps = generationSteps;

19 changes: 11 additions & 8 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1462,7 +1462,7 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const&
for (SizeType32 i = 0; i < mNumMicroBatches; ++i)
{
mDecoderInputBuffers.emplace_back(
getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
getMaxNumSequences(), getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
mDecoderOutputBuffers.emplace_back(getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(),
mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
}
@@ -1995,29 +1995,32 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(decoderStepAsync);

auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId());
auto& seqSlotLogits = decoderInputBuffers.logits;

auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
auto& contextRuntimeBuffers = mBuffers.at(contextBufferId);
auto const logitsIndex = (*mHandleContextLogits)(scheduledRequests.contextRequests,
contextRuntimeBuffers->numContextLogits, contextRuntimeBuffers->logits, *mDecoderBuffers, mModelConfig,
mRuntime->getBufferManager(), mRuntime->getStream(), contextRuntimeBuffers->mMedusaBuffers);
auto const logitsIndex = (*mHandleContextLogits)(decoderInputBuffers, scheduledRequests.contextRequests,
contextRuntimeBuffers->logits, contextRuntimeBuffers->numContextLogits, mModelConfig,
mRuntime->getBufferManager(), mDecoderBuffers->draftBuffers, contextRuntimeBuffers->mMedusaBuffers);

auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0;
auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId();
auto& genRuntimeBuffers = mBuffers.at(genBufferId);
(*mHandleGenerationLogits)(genLogitsIndex, scheduledRequests.generationRequests, *mDecoderBuffers, mModelConfig,
mRuntime->getBufferManager(), genRuntimeBuffers->logits, *genRuntimeBuffers);
(*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits,
genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, mDecoderBuffers->draftBuffers);

// Copy indirection output into input
// TODO: Could we avoid this by modifying batchDecoder to take a vector of tensors instead?
copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId);

mLogitsPostProcessorIsApplied
= (*mLogitsPostProcessor)(scheduledRequests.contextRequests, scheduledRequests.generationRequests,
mReplicateLogitsPostProcessor, *mDecoderBuffers, mWorldConfig, *mRuntime, mLogitsPostProcessorBatched);
mReplicateLogitsPostProcessor, seqSlotLogits, mWorldConfig, *mRuntime, mLogitsPostProcessorBatched);

if (mGuidedDecoder)
{
mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), mDecoderBuffers->logits);
mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), seqSlotLogits);
}

auto const fusedBufferId = getFusedBufferId();