-
Notifications
You must be signed in to change notification settings - Fork 269
cuda graph pool with LRU #964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
STwangyingrui
wants to merge
2
commits into
main
Choose a base branch
from
yr/cuda_graph_pool
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -346,6 +346,10 @@ def _decode( | |
) -> ModelOutput: | ||
if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): | ||
find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) | ||
if find_graph_batch_size is None: | ||
logger.error("No suitable graph batch size found for batch_size={model_input.batch_size}, return None.") | ||
return None | ||
|
||
padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) | ||
infer_state = self._create_inferstate(padded_model_input) | ||
copy_kv_index_to_req( | ||
|
@@ -356,7 +360,9 @@ def _decode( | |
) | ||
infer_state.init_some_extra_state(self, padded_model_input.input_ids) | ||
|
||
if self.graph.need_capture(find_graph_batch_size): | ||
# Check if a graph needs to be captured. | ||
# get_graph returns None if a graph for the batch_size doesn't exist. | ||
if self.graph.get_graph(find_graph_batch_size) is None: | ||
infer_state.is_cuda_graph = True | ||
model_output: ModelOutput = self.graph.capture_decode( | ||
self._token_forward, padded_model_input.input_ids, infer_state | ||
|
@@ -497,6 +503,10 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode | |
|
||
if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): | ||
find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) | ||
if find_graph_batch_size is None: | ||
logger.error("No suitable graph batch size found for batch_size={origin_batch_size}, return None.") | ||
return None | ||
|
||
padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) | ||
padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) | ||
infer_state0 = self._create_inferstate(padded_model_input0, 0) | ||
|
@@ -516,7 +526,9 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode | |
) | ||
infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) | ||
|
||
if self.graph.need_capture(find_graph_batch_size): | ||
# Check if a graph needs to be captured. | ||
# get_graph returns None if a graph for the batch_size doesn't exist. | ||
if self.graph.get_graph(find_graph_batch_size) is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the # Check if a graph needs to be captured. get_graph returns None if a graph for the batch_size doesn't exist.
if self.graph.get_graph(find_graph_batch_size) is None: |
||
infer_state0.is_cuda_graph = True | ||
infer_state1.is_cuda_graph = True | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition
self.graph.get_graph(find_graph_batch_size) is None
is used to determine if a new graph needs to be captured. Consider adding a comment explaining whyget_graph
is used here instead of the previousneed_capture
function, and why checking forNone
is the appropriate way to determine if a new graph is needed.