
[Cursor] Improve Reasoning Tokens Documentation and Implementation #99

Merged (3 commits) on Feb 27, 2025

1 change: 1 addition & 0 deletions .cursorrules
@@ -127,6 +127,7 @@ If needed, you can further use the `web_scraper.py` file to scrape the web page
- When using seaborn styles in matplotlib, use 'seaborn-v0_8' instead of 'seaborn' as the style name due to recent seaborn version changes
- Use `gpt-4o` as the model name for OpenAI. It is the latest GPT model and has vision capabilities as well. `o1` is the most advanced and expensive model from OpenAI. Use it when you need to do reasoning, planning, or get blocked.
- Use `claude-3-5-sonnet-20241022` as the model name for Claude. It is the latest Claude model and has vision capabilities as well.
- When running Python scripts that import from other local modules, use `PYTHONPATH=.` to ensure Python can find the modules. For example: `PYTHONPATH=. python tools/plan_exec_llm.py` instead of just `python tools/plan_exec_llm.py`. This is especially important when using relative imports.

# Multi-Agent Scratchpad

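For reference, a minimal sketch of the in-code alternative to the `PYTHONPATH=.` rule above, mirroring the `sys.path` fallback the test files in this PR already use; the final import is only an example:

```python
# From a file one level below the project root (e.g. tests/), add the root to
# sys.path so `tools.*` modules resolve regardless of the working directory.
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tools.llm_api import query_llm  # noqa: E402  (import after the path tweak)
```

The rule prefers `PYTHONPATH=.` because it needs no changes to the scripts themselves; the `sys.path` tweak is the per-file equivalent.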
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,17 @@
# Changelog

## [Unreleased]

### Added
- Comprehensive documentation for reasoning tokens across the codebase
- Detailed test cases for token tracking with different providers
- Clear docstrings explaining provider-specific token tracking behavior

### Changed
- Updated `query_llm` function to properly handle reasoning tokens for o1 model
- Improved test coverage for token tracking across all providers
- Enhanced documentation in test files to clarify token tracking behavior

### Fixed
- Proper handling of reasoning tokens for non-o1 models (explicitly set to None)
- Token tracking tests to verify correct behavior for all providers
55 changes: 52 additions & 3 deletions tests/test_llm_api.py
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import patch, MagicMock, mock_open
from tools.llm_api import create_llm_client, query_llm, load_environment
from tools.token_tracker import TokenUsage, APIResponse
from tools.token_tracker import TokenUsage, APIResponse, get_token_tracker
import os
import google.generativeai as genai
import io
@@ -202,18 +202,43 @@ def test_query_azure(self, mock_create_client):
)

@patch('tools.llm_api.create_llm_client')
def test_query_deepseek(self, mock_create_client):
@patch('tools.llm_api.get_token_tracker')
def test_query_deepseek(self, mock_get_tracker, mock_create_client):
"""Test querying DeepSeek API with token tracking.

DeepSeek uses an OpenAI-compatible API but, like most models, does not
support reasoning tokens (only OpenAI's o1 model has this feature).
"""
mock_create_client.return_value = self.mock_openai_client
mock_tracker = MagicMock()
mock_get_tracker.return_value = mock_tracker

# Set up mock response with usage data
self.mock_openai_response.usage = MagicMock()
self.mock_openai_response.usage.prompt_tokens = 10
self.mock_openai_response.usage.completion_tokens = 5
self.mock_openai_response.usage.total_tokens = 15

response = query_llm("Test prompt", provider="deepseek", model="deepseek-chat")
self.assertEqual(response, "Test OpenAI response")
self.mock_openai_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=[{"role": "user", "content": [{"type": "text", "text": "Test prompt"}]}],
temperature=0.7
)
# Verify token usage tracking for OpenAI-style providers
self.assertTrue(mock_tracker.track_request.called)
api_response = mock_tracker.track_request.call_args[0][0]
# Verify reasoning_tokens is None since this is not the o1 model
self.assertIsNone(api_response.token_usage.reasoning_tokens)

@patch('tools.llm_api.create_llm_client')
def test_query_anthropic(self, mock_create_client):
"""Test querying Anthropic API.

Note: Anthropic's API has its own token tracking system that differs from OpenAI's.
It does not support reasoning tokens (which is an OpenAI o1-specific feature).
"""
mock_create_client.return_value = self.mock_anthropic_client
response = query_llm("Test prompt", provider="anthropic", model="claude-3-5-sonnet-20241022")
self.assertEqual(response, "Test Anthropic response")
@@ -222,6 +247,7 @@ def test_query_anthropic(self, mock_create_client):
max_tokens=1000,
messages=[{"role": "user", "content": [{"type": "text", "text": "Test prompt"}]}]
)
# Note: Token tracking is not yet implemented for Anthropic

@patch('tools.llm_api.create_llm_client')
def test_query_gemini(self, mock_create_client):
@@ -243,8 +269,26 @@ def test_query_with_custom_model(self, mock_create_client):
)

@patch('tools.llm_api.create_llm_client')
def test_query_o1_model(self, mock_create_client):
@patch('tools.llm_api.get_token_tracker')
def test_query_o1_model(self, mock_get_tracker, mock_create_client):
"""Test querying OpenAI's o1 model.

The o1 model is special in that it:
1. Uses a different response format
2. Has a reasoning_effort parameter
3. Is the only model that provides reasoning_tokens in its response
"""
mock_create_client.return_value = self.mock_openai_client
mock_tracker = MagicMock()
mock_get_tracker.return_value = mock_tracker

# Set up mock response with usage data including reasoning tokens
self.mock_openai_response.usage = MagicMock()
self.mock_openai_response.usage.prompt_tokens = 10
self.mock_openai_response.usage.completion_tokens = 5
self.mock_openai_response.usage.total_tokens = 15
self.mock_openai_response.usage.reasoning_tokens = 3 # o1 model provides this

response = query_llm("Test prompt", provider="openai", model="o1")
self.assertEqual(response, "Test OpenAI response")
self.mock_openai_client.chat.completions.create.assert_called_once_with(
@@ -253,6 +297,11 @@ def test_query_o1_model(self, mock_create_client):
response_format={"type": "text"},
reasoning_effort="low"
)

# Verify token usage tracking includes reasoning tokens for o1 model
self.assertTrue(mock_tracker.track_request.called)
api_response = mock_tracker.track_request.call_args[0][0]
self.assertEqual(api_response.token_usage.reasoning_tokens, 3)

@patch('tools.llm_api.create_llm_client')
def test_query_with_existing_client(self, mock_create_client):
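A condensed sketch of the assertion pattern the new token-tracking tests share: build a mocked OpenAI-style response, patch `create_llm_client` and `get_token_tracker`, and inspect the `APIResponse` handed to `track_request`. The helper name and numbers are illustrative, not part of the test suite:

```python
from typing import Optional
from unittest.mock import MagicMock, patch

from tools.llm_api import query_llm


def reasoning_tokens_recorded(provider: str, model: str,
                              reasoning: Optional[int]) -> Optional[int]:
    """Run query_llm against mocks and return what the token tracker recorded."""
    # OpenAI-style mock response with usage data, as in the tests above.
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response"
    mock_response.usage.prompt_tokens = 10
    mock_response.usage.completion_tokens = 5
    mock_response.usage.total_tokens = 15
    mock_response.usage.reasoning_tokens = reasoning

    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value = mock_response

    with patch("tools.llm_api.create_llm_client", return_value=mock_client), \
         patch("tools.llm_api.get_token_tracker") as mock_get_tracker:
        mock_tracker = MagicMock()
        mock_get_tracker.return_value = mock_tracker
        query_llm("Test prompt", provider=provider, model=model)

    api_response = mock_tracker.track_request.call_args[0][0]
    return api_response.token_usage.reasoning_tokens


# Per the tests above, only the o1 call records a value; every other model records None:
#   reasoning_tokens_recorded("openai", "o1", 3)              -> 3
#   reasoning_tokens_recorded("deepseek", "deepseek-chat", 3) -> None
```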
76 changes: 30 additions & 46 deletions tests/test_plan_exec_llm.py
@@ -8,8 +8,8 @@

# Add the parent directory to the Python path so we can import the module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from tools.plan_exec_llm import load_environment, read_plan_status, read_file_content, create_llm_client, query_llm
from tools.plan_exec_llm import TokenUsage
from tools.plan_exec_llm import load_environment, read_plan_status, read_file_content, query_llm_with_plan
from tools.token_tracker import TokenUsage

class TestPlanExecLLM(unittest.TestCase):
def setUp(self):
@@ -18,9 +18,13 @@ def setUp(self):
self.original_env = dict(os.environ)
# Set test environment variables
os.environ['OPENAI_API_KEY'] = 'test_key'
os.environ['DEEPSEEK_API_KEY'] = 'test_deepseek_key'
os.environ['ANTHROPIC_API_KEY'] = 'test_anthropic_key'

self.test_env_content = """
OPENAI_API_KEY=test_key
DEEPSEEK_API_KEY=test_deepseek_key
ANTHROPIC_API_KEY=test_anthropic_key
"""
self.test_plan_content = """
# Multi-Agent Scratchpad
@@ -66,55 +70,35 @@ def test_read_file_content(self):
content = read_file_content('nonexistent_file.txt')
self.assertIsNone(content)

@patch('tools.plan_exec_llm.OpenAI')
def test_create_llm_client(self, mock_openai):
"""Test LLM client creation"""
mock_client = MagicMock()
mock_openai.return_value = mock_client

client = create_llm_client()
self.assertEqual(client, mock_client)
mock_openai.assert_called_once_with(api_key='test_key')

@patch('tools.plan_exec_llm.create_llm_client')
def test_query_llm(self, mock_create_client):
"""Test LLM querying"""
# Mock the OpenAI response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = "Test response"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_response.usage.completion_tokens_details = MagicMock()
mock_response.usage.completion_tokens_details.reasoning_tokens = None

mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_response
mock_create_client.return_value = mock_client
@patch('tools.llm_api.query_llm')
def test_query_llm_with_plan(self, mock_query_llm):
"""Test LLM querying with plan context"""
# Mock the LLM response
mock_query_llm.return_value = "Test response"

# Test with various combinations of parameters
response = query_llm("Test plan", "Test prompt", "Test file content")
self.assertEqual(response, "Test response")
with patch('tools.plan_exec_llm.query_llm') as mock_plan_query_llm:
mock_plan_query_llm.return_value = "Test response"
response = query_llm_with_plan("Test plan", "Test prompt", "Test file content", provider="openai", model="gpt-4o")
self.assertEqual(response, "Test response")
mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model="gpt-4o", provider="openai")

response = query_llm("Test plan", "Test prompt")
self.assertEqual(response, "Test response")
# Test with DeepSeek
response = query_llm_with_plan("Test plan", "Test prompt", provider="deepseek")
self.assertEqual(response, "Test response")
mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model=None, provider="deepseek")

response = query_llm("Test plan")
self.assertEqual(response, "Test response")
# Test with Anthropic
response = query_llm_with_plan("Test plan", provider="anthropic")
self.assertEqual(response, "Test response")
mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model=None, provider="anthropic")

# Verify the OpenAI client was called with correct parameters
mock_client.chat.completions.create.assert_called_with(
model="o1",
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": unittest.mock.ANY}
],
response_format={"type": "text"},
reasoning_effort="low"
)
# Verify the prompt format
calls = mock_plan_query_llm.call_args_list
for call in calls:
prompt = call[0][0]
self.assertIn("Multi-Agent Scratchpad", prompt)
self.assertIn("Test plan", prompt)

if __name__ == '__main__':
unittest.main()
16 changes: 14 additions & 2 deletions tools/llm_api.py
Original file line number Diff line number Diff line change
@@ -121,12 +121,24 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
Args:
prompt (str): The text prompt to send
client: The LLM client instance
model (str, optional): The model to use
model (str, optional): The model to use. Special handling for OpenAI's o1 model:
- Uses different response format
- Has reasoning_effort parameter
- Is the only model that provides reasoning_tokens in its response
provider (str): The API provider to use
image_path (str, optional): Path to an image file to attach

Returns:
Optional[str]: The LLM's response or None if there was an error

Note:
Token tracking behavior varies by provider:
- OpenAI-style APIs (OpenAI, Azure, DeepSeek, Local): Full token tracking
- Anthropic: Has its own token tracking system (input/output tokens)
- Gemini: Token tracking not yet implemented

Reasoning tokens are only available when using OpenAI's o1 model.
For all other models, reasoning_tokens will be None.
"""
if client is None:
client = create_llm_client(provider)
@@ -187,7 +199,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
reasoning_tokens=response.usage.completion_tokens_details.reasoning_tokens if hasattr(response.usage, 'completion_tokens_details') else None
reasoning_tokens=response.usage.reasoning_tokens if model.lower().startswith("o") else None # Only checks if model starts with "o", e.g., o1, o1-preview, o1-mini, o3, etc. Can update this logic to specific models in the future.
)

# Calculate cost
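One caveat for the simplified line above: depending on the OpenAI SDK version, reasoning tokens may surface as `usage.completion_tokens_details.reasoning_tokens` (the attribute the removed code read) rather than `usage.reasoning_tokens`. A defensive accessor covering both shapes could look like the sketch below; the helper is illustrative and not part of this PR:

```python
from typing import Optional


def reasoning_tokens_from_usage(usage, model: Optional[str]) -> Optional[int]:
    """Best-effort lookup of reasoning tokens for o-series models.

    Checks both attribute shapes that appear in this diff:
    usage.reasoning_tokens and usage.completion_tokens_details.reasoning_tokens.
    Returns None for every non o-series model, matching the new behavior.
    """
    if not model or not model.lower().startswith("o"):
        return None
    direct = getattr(usage, "reasoning_tokens", None)
    if direct is not None:
        return direct
    details = getattr(usage, "completion_tokens_details", None)
    return getattr(details, "reasoning_tokens", None) if details is not None else None
```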
67 changes: 10 additions & 57 deletions tools/plan_exec_llm.py
@@ -3,11 +3,11 @@
import argparse
import os
from pathlib import Path
from openai import OpenAI
from dotenv import load_dotenv
import sys
import time
from .token_tracker import TokenUsage, APIResponse, get_token_tracker
from tools.token_tracker import TokenUsage, APIResponse, get_token_tracker
from tools.llm_api import query_llm, create_llm_client

STATUS_FILE = '.cursorrules'

@@ -52,17 +52,8 @@ def read_file_content(file_path):
print(f"Error reading {file_path}: {e}", file=sys.stderr)
return None

def create_llm_client():
"""Create OpenAI client"""
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
raise ValueError("OPENAI_API_KEY not found in environment variables")
return OpenAI(api_key=api_key)

def query_llm(plan_content, user_prompt=None, file_content=None):
def query_llm_with_plan(plan_content, user_prompt=None, file_content=None, provider="openai", model=None):
"""Query the LLM with combined prompts"""
client = create_llm_client()

# Combine prompts
system_prompt = """"""

@@ -93,54 +84,16 @@ def query_llm(plan_content, user_prompt=None, file_content=None):
We will do the actual changes in the .cursorrules file.
"""

try:
start_time = time.time()
response = client.chat.completions.create(
model="o1",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": combined_prompt}
],
response_format={"type": "text"},
reasoning_effort="low"
)
thinking_time = time.time() - start_time

# Track token usage
token_usage = TokenUsage(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
reasoning_tokens=response.usage.completion_tokens_details.reasoning_tokens if hasattr(response.usage, 'completion_tokens_details') else None
)

# Calculate cost
cost = get_token_tracker().calculate_openai_cost(
token_usage.prompt_tokens,
token_usage.completion_tokens,
"o1"
)

# Track the request
api_response = APIResponse(
content=response.choices[0].message.content,
token_usage=token_usage,
cost=cost,
thinking_time=thinking_time,
provider="openai",
model="o1"
)
get_token_tracker().track_request(api_response)

return response.choices[0].message.content
except Exception as e:
print(f"Error querying LLM: {e}", file=sys.stderr)
return None
# Use the imported query_llm function
response = query_llm(combined_prompt, model=model, provider=provider)
return response

def main():
parser = argparse.ArgumentParser(description='Query OpenAI o1 model with project plan context')
parser = argparse.ArgumentParser(description='Query LLM with project plan context')
parser.add_argument('--prompt', type=str, help='Additional prompt to send to the LLM', required=False)
parser.add_argument('--file', type=str, help='Path to a file whose content should be included in the prompt', required=False)
parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use')
parser.add_argument('--model', type=str, help='The model to use (default depends on provider)')
args = parser.parse_args()

# Load environment variables
@@ -157,7 +110,7 @@ def main():
sys.exit(1)

# Query LLM and output response
response = query_llm(plan_content, args.prompt, file_content)
response = query_llm_with_plan(plan_content, args.prompt, file_content, provider=args.provider, model=args.model)
if response:
print('Following is the instruction on how to revise the Multi-Agent Scratchpad section in .cursorrules:')
print('========================================================')
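Assumed programmatic usage of the refactored entry point, based only on the signature shown above; the plan text and prompt are placeholders (in `main()` the plan is presumably read from `.cursorrules` via `read_plan_status`):

```python
from tools.plan_exec_llm import query_llm_with_plan

# Placeholder plan text; main() reads the real one from .cursorrules (STATUS_FILE).
plan = "# Multi-Agent Scratchpad\n- Current status: ..."

response = query_llm_with_plan(
    plan,
    user_prompt="Suggest the next step",  # illustrative prompt
    file_content=None,                    # optionally, a supporting file's contents
    provider="deepseek",                  # any provider accepted by tools.llm_api.query_llm
    model=None,                           # None falls back to the provider default
)
print(response)
```

From the command line, the same path is exercised through the new `--provider` and `--model` flags, together with the `PYTHONPATH=.` rule added to `.cursorrules`.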
10 changes: 10 additions & 0 deletions tools/token_tracker.py
@@ -14,6 +14,16 @@

@dataclass
class TokenUsage:
"""Token usage information for an LLM API request.

Attributes:
prompt_tokens: Number of tokens in the input prompt
completion_tokens: Number of tokens in the model's response
total_tokens: Total number of tokens used (prompt + completion)
reasoning_tokens: Number of tokens used for reasoning (only available for OpenAI's o1 model)
This is a special field that's only populated when using OpenAI's o1 model.
For all other models (including other OpenAI models), this will be None.
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
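A short usage sketch for the documented dataclass, assembled from the `TokenUsage`/`APIResponse` fields and tracker calls that appear elsewhere in this diff; all numbers are illustrative:

```python
from tools.token_tracker import APIResponse, TokenUsage, get_token_tracker

usage = TokenUsage(
    prompt_tokens=10,
    completion_tokens=5,
    total_tokens=15,
    reasoning_tokens=3,  # only meaningful for OpenAI's o1 model; None for everything else
)

tracker = get_token_tracker()
cost = tracker.calculate_openai_cost(usage.prompt_tokens, usage.completion_tokens, "o1")

tracker.track_request(APIResponse(
    content="...",       # the model's reply
    token_usage=usage,
    cost=cost,
    thinking_time=1.2,   # seconds; measured around the API call in the removed code
    provider="openai",
    model="o1",
))
```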