From 97d4e9464e759547a981b8266a9157c59d6b6bad Mon Sep 17 00:00:00 2001 From: wangbaohua Date: Tue, 18 Feb 2025 23:56:38 +0700 Subject: [PATCH] feat: Add Siliconflow API support - Integrate Siliconflow API client - Add authentication and API endpoints - Implement basic API operations --- .env.example | 3 ++- tests/test_llm_api.py | 35 ++++++++++++++++++++++++++++++++++- tools/llm_api.py | 16 ++++++++++++++-- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/.env.example b/.env.example index 12ff991..fb8290b 100644 --- a/.env.example +++ b/.env.example @@ -3,4 +3,5 @@ ANTHROPIC_API_KEY=your_anthropic_api_key_here DEEPSEEK_API_KEY=your_deepseek_api_key_here GOOGLE_API_KEY=your_google_api_key_here AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here -AZURE_OPENAI_MODEL_DEPLOYMENT=gpt-4o-ms \ No newline at end of file +AZURE_OPENAI_MODEL_DEPLOYMENT=gpt-4o-ms +SILICONFLOW_API_KEY=your_siliconflow_api_key_here \ No newline at end of file diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index a5cb842..eb84571 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -112,6 +112,14 @@ def setUp(self): self.mock_gemini_model.generate_content.return_value = self.mock_gemini_response self.mock_gemini_client.GenerativeModel.return_value = self.mock_gemini_model + # Set up SiliconFlow-style response + self.mock_siliconflow_response = MagicMock() + self.mock_siliconflow_choice = MagicMock() + self.mock_siliconflow_message = MagicMock() + self.mock_siliconflow_message.content = "Test Siliconflow response" + self.mock_siliconflow_choice.message = self.mock_siliconflow_message + self.mock_siliconflow_response.choices = [self.mock_siliconflow_choice] + # Mock environment variables self.env_patcher = patch.dict('os.environ', { 'OPENAI_API_KEY': 'test-openai-key', @@ -119,7 +127,8 @@ def setUp(self): 'ANTHROPIC_API_KEY': 'test-anthropic-key', 'GOOGLE_API_KEY': 'test-google-key', 'AZURE_OPENAI_API_KEY': 'test-azure-key', - 'AZURE_OPENAI_MODEL_DEPLOYMENT': 'test-model-deployment' + 'AZURE_OPENAI_MODEL_DEPLOYMENT': 'test-model-deployment', + 'SILICONFLOW_API_KEY': 'test-siliconflow-key' }) self.env_patcher.start() @@ -167,6 +176,17 @@ def test_create_deepseek_client(self, mock_openai): ) self.assertEqual(client, self.mock_openai_client) + @unittest.skipIf(skip_llm_tests, skip_message) + @patch('tools.llm_api.OpenAI') + def test_create_siliconflow_client(self, mock_openai): + mock_openai.return_value = self.mock_openai_client + client = create_llm_client("siliconflow") + mock_openai.assert_called_once_with( + api_key='test-siliconflow-key', + base_url="https://api.siliconflow.cn/v1" + ) + self.assertEqual(client, self.mock_openai_client) + @unittest.skipIf(skip_llm_tests, skip_message) @patch('tools.llm_api.Anthropic') def test_create_anthropic_client(self, mock_anthropic): @@ -234,6 +254,19 @@ def test_query_deepseek(self, mock_create_client): temperature=0.7 ) + @unittest.skipIf(skip_llm_tests, skip_message) + @patch('tools.llm_api.create_llm_client') + def test_query_siliconflow(self, mock_create_client): + self.mock_openai_client.chat.completions.create.return_value = self.mock_siliconflow_response + mock_create_client.return_value = self.mock_openai_client + response = query_llm("Test prompt", provider="siliconflow") + self.assertEqual(response, "Test Siliconflow response") + self.mock_openai_client.chat.completions.create.assert_called_once_with( + model="deepseek-ai/DeepSeek-R1", + messages=[{"role": "user", "content": [{"type": "text", "text": "Test prompt"}]}], + temperature=0.7 + ) + @unittest.skipIf(skip_llm_tests, skip_message) @patch('tools.llm_api.create_llm_client') def test_query_anthropic(self, mock_create_client): diff --git a/tools/llm_api.py b/tools/llm_api.py index 7702083..e051b1f 100644 --- a/tools/llm_api.py +++ b/tools/llm_api.py @@ -90,6 +90,14 @@ def create_llm_client(provider="openai"): api_key=api_key, base_url="https://api.deepseek.com/v1", ) + elif provider == "siliconflow": + api_key = os.getenv('SILICONFLOW_API_KEY') + if not api_key: + raise ValueError("SILICONFLOW_API_KEY not found in environment variables") + return OpenAI( + api_key=api_key, + base_url="https://api.siliconflow.cn/v1" + ) elif provider == "anthropic": api_key = os.getenv('ANTHROPIC_API_KEY') if not api_key: @@ -137,6 +145,8 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback elif provider == "deepseek": model = "deepseek-chat" + elif provider == "siliconflow": + model = "deepseek-ai/DeepSeek-R1" elif provider == "anthropic": model = "claude-3-sonnet-20240229" elif provider == "gemini": @@ -144,7 +154,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat elif provider == "local": model = "Qwen/Qwen2.5-32B-Instruct-AWQ" - if provider in ["openai", "local", "deepseek", "azure"]: + if provider in ["openai", "local", "deepseek", "azure", "siliconflow"]: messages = [{"role": "user", "content": []}] # Add text content @@ -232,7 +242,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat def main(): parser = argparse.ArgumentParser(description='Query an LLM with a prompt') parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True) - parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use') + parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure','siliconflow'], default='openai', help='The API provider to use') parser.add_argument('--model', type=str, help='The model to use (default depends on provider)') parser.add_argument('--image', type=str, help='Path to an image file to attach to the prompt') args = parser.parse_args() @@ -242,6 +252,8 @@ def main(): args.model = "gpt-4o" elif args.provider == "deepseek": args.model = "deepseek-chat" + elif args.provider == "siliconflow": + args.model = "deepseek-ai/DeepSeek-R1" elif args.provider == 'anthropic': args.model = "claude-3-5-sonnet-20241022" elif args.provider == 'gemini':