
Commit 32b14ba

[Refactor][Frontend] Keep all logic about reasoning in one class (#14428)
Signed-off-by: Ce Gao <[email protected]>
Parent: 2d9045f

18 files changed: +171 −200 lines
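
At a glance, the refactor collapses two parallel abstractions, the OpenAI-frontend ReasoningParser and the guided-decoding Reasoner, into the single vllm.reasoning package. A minimal before/after sketch of the import change, inferred only from the diffs below:

    # Before: two homes for reasoning logic.
    # from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
    # from vllm.model_executor.guided_decoding.reasoner import Reasoner

    # After: one class hierarchy serves both the OpenAI frontend
    # and guided decoding.
    from vllm.reasoning import ReasoningParser, ReasoningParserManager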
Lines changed: 41 additions & 11 deletions

@@ -3,74 +3,92 @@
 import pytest
 from transformers import AutoTokenizer
 
-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 parser_name = "deepseek_r1"
 start_token = "<think>"
 end_token = "</think>"
 
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
 SIMPLE_REASONING = {
     "output": "This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING = {
     "output": "This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 NO_CONTENT = {
     "output": "This is content",
     "reasoning_content": "This is content",
     "content": None,
+    "is_reasoning_end": False,
 }
 NO_REASONING_STREAMING = {
     "output": "This is a reasoning section",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": False,
 }
 MULTIPLE_LINES = {
     "output": "This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 MULTIPLE_LINES_WITH_THINK = {
     "output": "<think>This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 
 TEST_CASES = [
@@ -166,27 +184,39 @@
     ),
 ]
 
-# Global tokenizer initialization to avoid repeated loading
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-tokenizer.add_tokens([start_token, end_token])
-
 
 @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
 def test_reasoning(
     streaming: bool,
     param_dict: dict,
+    deepseek_r1_qwen_tokenizer,
 ):
-    output = tokenizer.tokenize(param_dict["output"])
+    output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
     # decode everything to tokens
     output_tokens: list[str] = [
-        tokenizer.convert_tokens_to_string([token]) for token in output
+        deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
+        for token in output
     ]
     parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
-        parser_name)(tokenizer)
+        parser_name)(deepseek_r1_qwen_tokenizer)
 
     reasoning, content = run_reasoning_extraction(parser,
                                                   output_tokens,
                                                   streaming=streaming)
 
     assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]
+
+    # Test is_reasoning_end
+    output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+    # Test extract_content
+    if param_dict["content"] is not None:
+        content = parser.extract_content_ids(output_ids)
+        assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
+            deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
+    else:
+        content = parser.extract_content_ids(output)
+        assert content == []
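
Outside of pytest, the same consolidated parser surface can be driven directly. A minimal sketch, assuming only the APIs visible in this diff (get_reasoning_parser, is_reasoning_end, extract_content_ids) plus standard transformers tokenizer calls; the output string is illustrative:

    from transformers import AutoTokenizer
    from vllm.reasoning import ReasoningParserManager

    tokenizer = AutoTokenizer.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    parser = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(
        tokenizer)

    # A finished reasoning section followed by visible content.
    ids = tokenizer.encode("pondering</think>The answer is 4",
                           add_special_tokens=False)
    assert parser.is_reasoning_end(ids)       # </think> was emitted
    content_ids = parser.extract_content_ids(ids)
    print(tokenizer.decode(content_ids))      # "The answer is 4"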
Lines changed: 2 additions & 4 deletions

@@ -2,10 +2,8 @@
 import pytest
 from transformers import AutoTokenizer
 
-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    DeltaMessage, run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 parser_name = "granite"
 START_REASONING = "Here is my thought process:"

tests/entrypoints/openai/reasoning_parsers/utils.py renamed to tests/reasoning/utils.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
+from vllm.reasoning import ReasoningParser
 
 
 class StreamingReasoningReconstructor:

vllm/engine/arg_utils.py

Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.plugins import load_general_plugins
+from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
@@ -1119,7 +1120,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             "--reasoning-parser",
             type=str,
-            choices=["deepseek_r1", "granite"],
+            choices=list(ReasoningParserManager.reasoning_parsers),
             default=None,
             help=
             "Select the reasoning parser depending on the model that you're "
vllm/engine/llm_engine.py

Lines changed: 3 additions & 2 deletions

@@ -2080,8 +2080,9 @@ def _build_logits_processors(
         guided_decoding.backend = guided_decoding.backend or \
             self.decoding_config.guided_decoding_backend
 
-        logger.debug("Reasoning backend: %s",
-                     self.decoding_config.reasoning_backend)
+        if self.decoding_config.reasoning_backend is not None:
+            logger.debug("Building with reasoning backend %s",
+                         self.decoding_config.reasoning_backend)
 
         processor = get_local_guided_decoding_logits_processor(
             guided_params=guided_decoding,

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,6 @@
                                               TranscriptionRequest,
                                               TranscriptionResponse,
                                               UnloadLoRAAdapterRequest)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -85,6 +84,7 @@
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import MistralTokenizer

vllm/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 2 deletions

@@ -23,8 +23,6 @@
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
     DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
     RequestResponseMetadata, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                     clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -33,6 +31,7 @@
     MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

vllm/model_executor/guided_decoding/__init__.py

Lines changed: 11 additions & 4 deletions

@@ -5,10 +5,10 @@
 from typing import TYPE_CHECKING
 
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
+from vllm.reasoning import ReasoningParserManager
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
         model_config: ModelConfig,
         reasoning_backend: str | None = None) -> LogitsProcessor | None:
 
-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)
 
     guided_params = maybe_backend_fallback(guided_params)
 
@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
         reasoning_backend: str | None = None) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
 
-    # Get the reasoner if needed, it will be None if reasoning_
-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)
 
     # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
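
Both call sites now inline what the deleted get_reasoner helper used to do. Factored out here as a standalone sketch; the build_reasoner name and the tokenizer choice are illustrative, not part of the diff, while the body is the logic shown above:

    from transformers import AutoTokenizer
    from vllm.reasoning import ReasoningParserManager

    def build_reasoner(tokenizer, reasoning_backend):
        # Mirrors the inlined logic above: no backend configured, no reasoner.
        if reasoning_backend is None:
            return None
        reasoner_class = ReasoningParserManager.get_reasoning_parser(
            reasoning_backend)
        return reasoner_class(tokenizer)

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    assert build_reasoner(tokenizer, None) is None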

vllm/model_executor/guided_decoding/outlines_decoding.py

Lines changed: 4 additions & 4 deletions

@@ -12,7 +12,7 @@
 
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
+from vllm.reasoning import ReasoningParser
 from vllm.sampling_params import GuidedDecodingParams
 
 
@@ -61,7 +61,7 @@ class GuidedDecodingMode(Enum):
 async def get_outlines_guided_decoding_logits_processor(
     guided_params: GuidedDecodingParams,
     tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
 def get_local_outlines_guided_decoding_logits_processor(
     guided_params: GuidedDecodingParams,
     tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -141,7 +141,7 @@ def _get_logits_processor(
     tokenizer: PreTrainedTokenizerBase,
     mode: GuidedDecodingMode,
     whitespace_pattern: Union[str, None],
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,

vllm/model_executor/guided_decoding/outlines_logits_processors.py

Lines changed: 7 additions & 7 deletions

@@ -34,8 +34,8 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.platforms import current_platform
+from vllm.reasoning import ReasoningParser
 
 logger = init_logger(__name__)
@@ -49,9 +49,9 @@
 
 class BaseLogitsProcessor:
 
-    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
+    def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
         self._guide: Guide = guide
-        self._reasoner: Optional[Reasoner] = reasoner
+        self._reasoner: Optional[ReasoningParser] = reasoner
         # CFGState is used for the FSM state for CFGGuide
         self._fsm_state: DefaultDict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
@@ -69,7 +69,7 @@ def __call__(self, input_ids: List[int],
             # Remove the reasoning tokens from the input_ids
             # We need this because our implementation relies on the
             # hash of the input_ids to store the FSM state.
-            input_ids = self._reasoner.extract_content(input_ids)
+            input_ids = self._reasoner.extract_content_ids(input_ids)
 
         seq_id = hash(tuple(input_ids))
 
@@ -142,7 +142,7 @@ def __init__(
         self,
         regex_string: str,
         tokenizer: PreTrainedTokenizerBase,
-        reasoner: Optional[Reasoner],
+        reasoner: Optional[ReasoningParser],
     ):
         """Compile the FSM that drives the regex-structured generation.
 
@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema: Union[str, Dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Union[str, None],
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
         """Compile the FSM that drives the JSON-guided generation.
 
         Parameters
@@ -203,7 +203,7 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
         return CFGGuide(cfg, tokenizer)
 
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
         """Compile the FSM that drives the context free grammar generation.
 
         Parameters
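
The rename from extract_content to extract_content_ids matters in BaseLogitsProcessor.__call__ because the FSM state is keyed by a hash of the token ids, so reasoning tokens must be stripped before hashing or the key would never match across steps. A simplified sketch of that keying step; the fsm_key helper is hypothetical, the real logic lives inline in __call__ as shown above:

    def fsm_key(self, input_ids: list[int]) -> int:
        if self._reasoner is not None:
            # Drop reasoning tokens so the hash sees only content ids.
            input_ids = self._reasoner.extract_content_ids(input_ids)
        return hash(tuple(input_ids))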

vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
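
With the parser now owning end-of-reasoning detection, the standalone DeepSeek reasoner became redundant. Its core check presumably reduces to scanning for the </think> token id, consistent with the is_reasoning_end behavior the test above asserts; this reconstruction is offered only as an assumption, not a copy of the deleted file:

    class ThinkEndCheck:  # hypothetical stand-in for the deleted class
        def __init__(self, tokenizer):
            self.end_token_id = tokenizer.convert_tokens_to_ids("</think>")

        def is_reasoning_end(self, input_ids: list[int]) -> bool:
            # Reasoning ends once the closing tag has been generated.
            return self.end_token_id in input_ids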
