diff --git a/.gitignore b/.gitignore index 7dbdfee..5084528 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,16 @@ venv/ ENV/ env.bak/ venv.bak/ +pyenv/ +*.env + +# VS Code +.vscode/ +.history/ +*.code-workspace + +# Specstory +.specstory/ # Spyder project settings .spyderproject diff --git a/README.md b/README.md index 1038d38..94169cc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Grok3 API -Grok3 is cool, smart, and useful, but there is no official API available. This is an **unofficial Python client** for interacting with the Grok 3 API. It leverages cookies from browser requests to authenticate and access the API endpoints. +Grok3 is cool, smart, and useful, but there is no official API available. This is an **unofficial Python client** for interacting with the Grok 3 API. It leverages cookies from browser requests to authenticate and access the API endpoints. The API also provides OpenAI-compatible endpoints for easy integration with existing applications. --- @@ -49,9 +49,11 @@ Example cookie string from a curl command: ### 4. Use the Client +#### 4.1 Direct Client Usage + Pass the extracted cookie values to the GrokClient and send a message: -``` +```python from grok_client import GrokClient # Your cookie values @@ -71,6 +73,64 @@ response = client.send_message("write a poem") print(response) ``` +#### 4.2 OpenAI-Compatible API Server + +The package includes an OpenAI-compatible API server that allows you to use Grok with any OpenAI-compatible client library or application. + +##### Start the Server + +1. Create a `.env` file in the project directory using the provided `.env.example` template: +```bash +cp grok_client/.env.example .env +``` + +2. Update the `.env` file with your Grok cookie values: +```env +GROK_SSO=your_sso_cookie +GROK_SSO_RW=your_sso_rw_cookie +# Optional configurations +API_HOST=127.0.0.1 +API_PORT=8000 +MODEL_NAME=grok-3 +``` + +3. 
Start the API server: +```bash +uvicorn grok_client.server:app --reload --host 0.0.0.0 --port 8000 +``` + +##### Use with OpenAI Python Client + +```python +from openai import OpenAI + +# Initialize client pointing to local server +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="dummy-key" # Any non-empty string will work +) + +# Create a chat completion +response = client.chat.completions.create( + model="grok-3", # Model name can be configured in .env + messages=[ + {"role": "user", "content": "Hello, how can you help me?"} + ] +) + +print(response.choices[0].message.content) +``` + +##### Interactive Chat Script + +The package includes an interactive chat script that uses the OpenAI-compatible endpoint: + +```bash +python grok_client/interactive.py +``` + +This provides a command-line interface for chatting with Grok using the OpenAI-compatible API. + ### 5. Optional: Add Memory with Mem0 If you want Grok to remember conversations, you can integrate it with Mem0. Mem0 provides a memory layer for AI applications. diff --git a/grok_client/.env.example b/grok_client/.env.example new file mode 100644 index 0000000..9dd9afb --- /dev/null +++ b/grok_client/.env.example @@ -0,0 +1,14 @@ +# Grok API Configuration + +# API Server Configuration +API_HOST=127.0.0.1 +API_PORT=8000 + +# Grok Model Configuration +MODEL_NAME=grok-3 + +# Authentication Cookies +# Replace these with your actual Grok cookies +# You can obtain these from your browser after logging into Grok +GROK_SSO=your_sso_cookie_value_here +GROK_SSO_RW=your_sso_rw_cookie_value_here \ No newline at end of file diff --git a/grok_client/__init__.py b/grok_client/__init__.py index 1f82267..065bfa9 100644 --- a/grok_client/__init__.py +++ b/grok_client/__init__.py @@ -1,4 +1,33 @@ +""" +Grok Client Package +=================== + +This package provides client utilities for interacting with a Grok API. 
+It includes: +- `GrokClient`: A low-level client for sending messages directly to the Grok service. +- `GrokOpenAIClient`: An OpenAI-compatible client that wraps the Grok API, + allowing it to be used with tools and libraries designed for the OpenAI API structure. +- Custom error classes for more specific error handling. +- A FastAPI server (`server.py`) that exposes an OpenAI-compatible API endpoint + backed by the Grok service. + +The main components intended for direct use are typically `GrokClient` or +`GrokOpenAIClient`. +""" from .client import GrokClient +from .grok_openai_client import GrokOpenAIClient +from .errors import GrokClientError, GrokAPIError, AuthenticationError, NetworkError, ConfigurationError __version__ = "0.1.0" -__all__ = ['GrokClient'] \ No newline at end of file +"""The version of the grok_client package.""" + +__all__ = [ + 'GrokClient', + 'GrokOpenAIClient', + 'GrokClientError', + 'GrokAPIError', + 'AuthenticationError', + 'NetworkError', + 'ConfigurationError' +] +"""Publicly exposed names from the grok_client package.""" \ No newline at end of file diff --git a/grok_client/client.py b/grok_client/client.py index 1f2510f..6ab28aa 100644 --- a/grok_client/client.py +++ b/grok_client/client.py @@ -1,22 +1,119 @@ import requests import json +import time +""" +Grok Direct Client Module +========================= + +This module provides the `GrokClient` class, a low-level client for direct interaction +with the Grok API. It handles request preparation, sending messages, and processing +responses, including error handling specific to Grok API interactions. + +It is intended for use cases where direct control over the Grok API is needed, +as opposed to using an OpenAI-compatible interface. 
+""" +import logging +import re +import os +import requests # Import requests for type hinting session and response +from typing import Dict, List, Any, Union, Optional # Import necessary types + +from .errors import GrokAPIError, AuthenticationError, NetworkError, ConfigurationError, GrokClientError + +# Set up logging +_DEFAULT_LOG_LEVEL: str = "INFO" +_ALLOWED_LOG_LEVELS: List[str] = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + +def get_log_level_from_env() -> str: + """ + Retrieves and validates the log level from the GROK_LOG_LEVEL environment variable. + + If the environment variable is not set or contains an invalid value, + a warning is logged, and the default log level ("INFO") is returned. + + Returns: + str: The validated log level string (e.g., "DEBUG", "INFO"). + """ + env_log_level: str = os.environ.get("GROK_LOG_LEVEL", _DEFAULT_LOG_LEVEL).upper() + if env_log_level not in _ALLOWED_LOG_LEVELS: + # Log a warning using a temporary basic config if the level is invalid, then default + # This initial basicConfig is for this specific warning message only. + # The main basicConfig later will use the determined (or default) LOG_LEVEL. + logging.basicConfig(level=logging.WARNING) + logging.warning(f"Invalid GROK_LOG_LEVEL '{env_log_level}'. Defaulting to '{_DEFAULT_LOG_LEVEL}'.") + return _DEFAULT_LOG_LEVEL + return env_log_level + +LOG_LEVEL: str = get_log_level_from_env() +# The main basicConfig for the logger, using the determined LOG_LEVEL. +# Note: If the logger was already configured by the above warning, this might reconfigure it +# or be ignored depending on Python's logging internals. Ideally, only configure once. +# To ensure single configuration, we can check if root logger has handlers. +if not logging.root.handlers: + logging.basicConfig(level=LOG_LEVEL) +else: # If already configured (e.g. 
by the warning message), just set level + logging.getLogger().setLevel(LOG_LEVEL) + +logger = logging.getLogger(__name__) class GrokClient: - def __init__(self, cookies): + """ + A client for interacting directly with the Grok API. + + This client handles the necessary authentication (via cookies) and request + formatting to send messages to Grok and retrieve responses. It is designed + for lower-level access to the Grok service. + + Attributes: + base_url (str): The base URL for the Grok API chat endpoint. + cookies (Dict[str, str]): Cookies used for authentication, must include 'sso' and 'sso-rw'. + headers (Dict[str, str]): Standard headers sent with each request. + """ + base_url: str + cookies: Dict[str, str] + headers: Dict[str, str] + + def __init__(self, cookies: Dict[str, str]) -> None: """ - Initialize the Grok client with cookie values + Initializes the GrokClient. Args: - cookies (dict): Dictionary containing cookie values - - x-anonuserid - - x-challenge - - x-signature - - sso - - sso-rw + cookies (Dict[str, str]): A dictionary containing authentication cookies. + Must include 'sso' and 'sso-rw' keys with their respective token values. + + Raises: + AuthenticationError: If 'sso' or 'sso-rw' cookies are missing or empty. 
""" self.base_url = "https://grok.com/rest/app-chat/conversations/new" - self.cookies = cookies - self.headers = { + + # Convert cookie string to dict if needed + if isinstance(cookies.get('Cookie'), str): + cookie_dict = {} + for cookie in cookies.get('Cookie', '').split(';'): + if cookie.strip(): + name, value = cookie.strip().split('=', 1) + cookie_dict[name.strip()] = value.strip() + self.cookies = cookie_dict + else: + self.cookies = cookies # type: ignore # Assuming cookies can be other types initially + + # Ensure self.cookies is Dict[str, str] after processing + if not isinstance(self.cookies, dict) or \ + not all(isinstance(k, str) and isinstance(v, str) for k, v in self.cookies.items()): + # This case should ideally be handled by stricter input validation or type checking earlier + # For now, if it's not a Dict[str, str] after processing, it's an issue. + # We'll assume the conversion logic above makes it Dict[str, str] or raises. + # If not, the following checks might fail or behave unexpectedly. 
+ # To be robust, one might add: self.cookies = {} if not isinstance(self.cookies, dict) else self.cookies + pass + + + if not self.cookies.get('sso') or not self.cookies.get('sso-rw'): + raise AuthenticationError("Missing required SSO cookies (sso, sso-rw) for GrokClient initialization.") + + logger.debug(f"Using cookies: {self.cookies}") + + self.headers = { # type: Dict[str, str] "accept": "*/*", "accept-language": "en-GB,en;q=0.9", "content-type": "application/json", @@ -25,28 +122,37 @@ def __init__(self, cookies): "referer": "https://grok.com/", "sec-ch-ua": '"Not/A)Brand";v="8", "Chromium";v="126", "Brave";v="126"', "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": '"macOS"', + "sec-ch-ua-platform": '"Windows"', "sec-fetch-dest": "empty", "sec-fetch-mode": "cors", "sec-fetch-site": "same-origin", "sec-gpc": "1", - "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36" + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36" } + logger.debug(f"Initialized GrokClient with headers: {self.headers}") + + def _prepare_payload(self, message: str) -> Dict[str, Any]: + """ + Prepares the JSON payload for sending a message to the Grok API. + + Args: + message (str): The user's input message. - def _prepare_payload(self, message): - """Prepare the default payload with the user's message""" - return { + Returns: + Dict[str, Any]: A dictionary representing the JSON payload. 
+ """ + payload: Dict[str, Any] = { "temporary": False, - "modelName": "grok-3", + "modelName": "grok-3", # Or make this configurable via __init__ "message": message, "fileAttachments": [], "imageAttachments": [], "disableSearch": False, - "enableImageGeneration": True, + "enableImageGeneration": False, "returnImageBytes": False, "returnRawGrokInXaiRequest": False, - "enableImageStreaming": True, - "imageGenerationCount": 2, + "enableImageStreaming": False, + "imageGenerationCount": 0, "forceConcise": False, "toolOverrides": {}, "enableSideBySide": True, @@ -56,44 +162,210 @@ def _prepare_payload(self, message): "deepsearchPreset": "", "isReasoning": False } + logger.debug(f"Prepared payload: {payload}") + return payload + + def _clean_json_response(self, response: str) -> str: + """ + Cleans up a JSON-like string response from Grok. + + This method attempts to remove common markdown code block delimiters (```json ... ```) + and then tries to parse and re-serialize the JSON to ensure it's well-formed. + If the input string contains nested JSON strings (e.g., in 'response' or + 'function_call.arguments'), it attempts to extract and format those. + + Args: + response (str): The raw string response which may contain JSON. + + Returns: + str: A cleaned-up JSON string, or the original string if it's not valid JSON + or if extraction logic doesn't apply. 
+ """ + # Remove markdown code blocks + cleaned_response: str = re.sub(r'```json\s*', '', response) + cleaned_response = re.sub(r'```\s*$', '', cleaned_response) + + try: + # Try to parse as JSON + json_data: Any = json.loads(cleaned_response) + + # If the response has a nested response or function_call, extract it + if isinstance(json_data, dict): + if "response" in json_data and isinstance(json_data["response"], (dict, list, str)): # Check type of inner response + json_data = json_data["response"] + elif "function_call" in json_data and \ + isinstance(json_data.get("function_call"), dict) and \ + "arguments" in json_data["function_call"] and \ + isinstance(json_data["function_call"]["arguments"], str): + # Attempt to parse arguments if they are a string-encoded JSON + try: + json_data = json.loads(json_data["function_call"]["arguments"]) + except json.JSONDecodeError: + # If arguments are not valid JSON, keep them as string or handle as error + # For now, we assume it should be parsable if it's a function call argument. + # If not, it might remain a string within the function_call structure. + # This part might need more specific error handling or schema validation. + pass # Keep json_data as the function_call dict if arguments are not JSON string. + + return json.dumps(json_data, indent=2) + except json.JSONDecodeError: + # If it's not valid JSON after cleaning, return the cleaned string as is. + return cleaned_response - def send_message(self, message): + def send_message(self, message: str, stream_callback: Optional[Any] = None) -> str: """ - Send a message to Grok and collect the streaming response + Sends a message to the Grok API and processes the response. + + This method can handle both streaming and non-streaming responses. + If `stream_callback` is provided, it's expected that this method (or the underlying + request logic) will invoke the callback with chunks of data from the stream. 
+ The current implementation primarily collects a full response, but the + `stream_callback` argument is kept for compatibility with potential future + true streaming implementations or for how the server part uses it. Args: - message (str): The user's input message + message (str): The user's input message. + stream_callback (Optional[Any]): An optional callback function to handle + streaming data. (Note: Current client implementation collects full response; + true client-side streaming via this callback is not fully implemented here). Returns: - str: The complete response from Grok + str: The complete, cleaned response text from Grok. If the response is streamed + and `stream_callback` is used, the return value might be the final + accumulated response or an empty string if all data is handled by callback. + Currently, it returns the `xai_generated_text` from the last valid packet. + + Raises: + AuthenticationError: If the API returns a 401 or 403 status code. + GrokAPIError: For other 4xx/5xx API errors or if the API response + indicates an error (e.g., within the JSON payload). + NetworkError: For network-level issues like connection errors or timeouts. 
""" - payload = self._prepare_payload(message) - response = requests.post( - self.base_url, - headers=self.headers, - cookies=self.cookies, - json=payload, - stream=True - ) - - full_response = "" - - for line in response.iter_lines(): - if line: - decoded_line = line.decode('utf-8') - try: - json_data = json.loads(decoded_line) - result = json_data.get("result", {}) - response_data = result.get("response", {}) - - if "modelResponse" in response_data: - return response_data["modelResponse"]["message"] - - token = response_data.get("token", "") - if token: - full_response += token - - except json.JSONDecodeError: - continue - - return full_response.strip() \ No newline at end of file + try: + logger.debug(f"Sending message to Grok: {message}") + payload: Dict[str, Any] = self._prepare_payload(message) + + logger.debug(f"Making POST request to {self.base_url}") + logger.debug(f"Using cookies: {self.cookies}") + + session: requests.Session = requests.Session() + for cookie_name, cookie_value in self.cookies.items(): + session.cookies.set(cookie_name, cookie_value) + + response: requests.Response = session.post( + self.base_url, # type: ignore # self.base_url is str + headers=self.headers, # type: ignore # self.headers is Dict[str, str] + json=payload, + stream=True # Always stream to inspect line by line + ) + + logger.debug(f"Response status code: {response.status_code}") + + if response.status_code == 401 or response.status_code == 403: + raise AuthenticationError(f"Authentication failed with Grok API: {response.status_code} - {response.text}") + if response.status_code >= 400: + raise GrokAPIError(f"Grok API request failed: {response.status_code} - {response.text}") + + full_response_accumulator: str = "" + last_processed_response_text: Optional[str] = None # To store the text from the last meaningful packet + + logger.debug("Processing response stream...") + for line in response.iter_lines(): # type: bytes + if line: + decoded_line: str = "" + try: + decoded_line 
= line.decode('utf-8') + logger.debug(f"Received line: {decoded_line}") + + # Assuming each line is a separate JSON object, as per typical SSE-like streams + json_data: Any = json.loads(decoded_line) + logger.debug(f"Parsed JSON from line: {json_data}") + + # Error checking within the JSON payload itself + if isinstance(json_data, dict) and "error" in json_data: + error_msg: str = str(json_data.get("error", "Unknown API error in JSON payload")) + logger.error(f"API Error in response payload: {error_msg}") + raise GrokAPIError(f"Grok API returned an error in payload: {error_msg}") + + # Data extraction logic (this part is highly dependent on Grok's actual streaming format) + # The existing code seems to expect a structure like: + # {"result": {"response": {"modelResponse": {"message": "..."}}}} or + # {"result": {"response": {"token": "..."}}} + # And also a top-level "xai_generated_text" in some cases (often in final packets). + + current_text_piece: Optional[str] = None + if isinstance(json_data, dict): + if "xai_generated_text" in json_data: # Often in final non-streaming style packet + current_text_piece = json_data["xai_generated_text"] + + result_data = json_data.get("result") + if isinstance(result_data, dict): + response_data = result_data.get("response") + if isinstance(response_data, dict): + if "modelResponse" in response_data and \ + isinstance(response_data["modelResponse"], dict) and \ + "message" in response_data["modelResponse"]: + # This seems like a full message override + current_text_piece = response_data["modelResponse"]["message"] + elif "token" in response_data: # Streaming token + token_text = response_data.get("token") + if isinstance(token_text, str): + full_response_accumulator += token_text + # For streaming, current_text_piece might be just the token + # or we rely on full_response_accumulator. + # The old logic used `last_response` for the whole json_data. + # Let's assume for now the goal is to get any text. 
+ if not current_text_piece: # Prioritize xai_generated_text or full message + current_text_piece = token_text + + + if stream_callback and current_text_piece is not None: + # If there's a callback, send the current piece of data. + # The callback might receive individual tokens or larger chunks. + # The structure of `json_data` or `current_text_piece` should align with callback needs. + # The original server.py implies callback gets the whole json_data dict. + stream_callback(json_data) # Pass the whole JSON object to callback + + if current_text_piece is not None: + last_processed_response_text = current_text_piece # Keep track of text from last meaningful packet + + except json.JSONDecodeError: + logger.warning(f"Failed to decode JSON from line: {decoded_line if decoded_line else line.decode('utf-8', errors='ignore')}") + # Decide if this is fatal or skippable. For now, skip. + continue + except GrokAPIError: # Re-raise API errors found in payload + raise + except Exception as e: # Catch other errors during line processing + logger.error(f"Error processing response line: {decoded_line} - Error: {e}", exc_info=True) + raise GrokAPIError(f"Error processing Grok API response line: {str(e)}") + + # After iterating through all lines: + # Determine what to return. The old logic returned based on 'last_response'. + # If streaming tokens were accumulated: + if full_response_accumulator: + logger.debug(f"Returning accumulated streaming response: {full_response_accumulator.strip()}") + return self._clean_json_response(full_response_accumulator.strip()) + + # If individual text pieces were processed (e.g., from xai_generated_text or modelResponse.message): + if last_processed_response_text is not None: + logger.debug(f"Returning last processed text: {last_processed_response_text.strip()}") + return self._clean_json_response(last_processed_response_text.strip()) + + # If we reach here, no meaningful data was extracted or accumulated. 
+ logger.error("No valid/extractable content received from Grok API stream.") + raise GrokAPIError("No valid/extractable content message received from Grok API stream.") + + except requests.exceptions.ConnectionError as e: + logger.error(f"Connection error while contacting Grok API: {str(e)}", exc_info=True) + raise NetworkError(f"Connection error while contacting Grok API: {str(e)}") + except requests.exceptions.Timeout as e: + logger.error(f"Timeout while contacting Grok API: {str(e)}", exc_info=True) + raise NetworkError(f"Timeout while contacting Grok API: {str(e)}") + except requests.exceptions.RequestException as e: # Other request-related errors + logger.error(f"Network request to Grok API failed: {str(e)}", exc_info=True) + raise NetworkError(f"Network request to Grok API failed: {str(e)}") + except GrokClientError: # Re-raise already handled custom Grok client errors + raise + except Exception as e: # Catch-all for unexpected errors + logger.error(f"An unexpected error occurred while processing Grok response: {str(e)}", exc_info=True) + raise GrokAPIError(f"An unexpected error occurred while processing Grok response: {str(e)}") \ No newline at end of file diff --git a/grok_client/errors.py b/grok_client/errors.py new file mode 100644 index 0000000..60fb8ac --- /dev/null +++ b/grok_client/errors.py @@ -0,0 +1,58 @@ +""" +Custom exception classes for the Grok client. + +This module defines a hierarchy of custom exceptions to provide more specific error +information when interacting with the Grok API or the client library itself. +""" + +class GrokClientError(Exception): + """ + Base class for all custom exceptions raised by the Grok client library. + + This exception can be used to catch any error originating from the Grok client, + allowing for a general way to handle client-specific issues. + """ + pass + +class GrokAPIError(GrokClientError): + """ + Raised when the Grok API returns an error response. 
+ + This typically indicates a problem on the server-side or an issue with the + request that the API itself has identified (e.g., invalid parameters, + rate limits, server errors within Grok's infrastructure). The error message + will usually contain details from the API's error response. + """ + pass + +class AuthenticationError(GrokClientError): + """ + Raised for authentication-related failures. + + This can occur if SSO tokens (cookies) are missing, invalid, or expired, + preventing successful authentication with the Grok API. It may also indicate + permission issues if the authenticated user does not have access to a + requested resource or model. + """ + pass + +class ConfigurationError(GrokClientError): + """ + Raised for errors related to client or environment configuration. + + This includes issues such as missing or invalid essential environment variables + (e.g., API host, port if not using defaults), incorrect cookie paths, or other + setup problems that prevent the client from initializing or operating correctly. + """ + pass + +class NetworkError(GrokClientError): + """ + Raised for network-related issues encountered while communicating with the Grok API. + + This can include problems like connection timeouts, DNS resolution failures, + or other issues preventing the client from reaching the Grok API servers. + It generally suggests a problem with the network connection between the client + and the API endpoint. + """ + pass diff --git a/grok_client/grok_openai_client.py b/grok_client/grok_openai_client.py new file mode 100644 index 0000000..35d874f --- /dev/null +++ b/grok_client/grok_openai_client.py @@ -0,0 +1,98 @@ +import os +import openai +import logging + +logger = logging.getLogger(__name__) + +class GrokOpenAIClient: + """ + A client for interacting with a Grok API that mimics the OpenAI API structure. 
+ """ + def __init__(self, api_host=None, api_port=None, model_name=None, sso_token=None, sso_rw_token=None, load_from_env=True): + """ + Initializes the GrokOpenAIClient. + + Args: + api_host (str, optional): The API host. Defaults to '127.0.0.1'. + api_port (str, optional): The API port. Defaults to '8000'. + model_name (str, optional): The model name. Defaults to 'grok-3'. + sso_token (str, optional): The SSO token. + sso_rw_token (str, optional): The SSO RW token. + load_from_env (bool, optional): Whether to load parameters from environment variables. Defaults to True. + + Raises: + ValueError: If essential cookie information (sso_token, sso_rw_token) is missing. + """ + if load_from_env: + api_host = api_host or os.getenv('GROK_API_HOST', '127.0.0.1') + api_port = api_port or os.getenv('GROK_API_PORT', '8000') + model_name = model_name or os.getenv('GROK_MODEL_NAME', 'grok-3') + sso_token = sso_token or os.getenv('GROK_SSO_TOKEN') + sso_rw_token = sso_rw_token or os.getenv('GROK_SSO_RW_TOKEN') + + if not sso_token or not sso_rw_token: + raise ValueError("SSO token and SSO RW token are required.") + + self.model_name = model_name + base_url = f"http://{api_host}:{api_port}/v1" + + self.client = openai.OpenAI( + base_url=base_url, + api_key="dummy_key", # OpenAI client requires an API key, but Grok uses SSO tokens + default_headers={ + "Cookie": f"sso={sso_token}; sso-rw={sso_rw_token}" + } + ) + logger.info(f"GrokOpenAIClient initialized with model: {self.model_name}, API: {base_url}") + + def chat_completion(self, messages, stream=False, temperature=0.7, response_format=None, max_tokens=None, **kwargs): + """ + Creates a chat completion using the configured Grok model. + + Args: + messages (list): A list of message objects, similar to OpenAI's API. + stream (bool, optional): Whether to stream the response. Defaults to False. + temperature (float, optional): Sampling temperature. Defaults to 0.7. + response_format (dict, optional): The response format. 
Defaults to None. + max_tokens (int, optional): The maximum number of tokens to generate. Defaults to None. + **kwargs: Additional keyword arguments to pass to the OpenAI client. + + Returns: + The response from the OpenAI client, which could be a streaming object or a completion object. + """ + params = { + "model": self.model_name, + "messages": messages, + "stream": stream, + "temperature": temperature, + } + if response_format is not None: + params["response_format"] = response_format + if max_tokens is not None: + params["max_tokens"] = max_tokens + + params.update(kwargs) # Add any other common parameters + + logger.debug(f"Sending chat completion request with params: {params}") + return self.client.chat.completions.create(**params) + + def process_streaming_response(self, stream_response): + """ + Processes a streaming response, printing each chunk's content and returning the full response. + + Args: + stream_response: The streaming response object from chat_completion with stream=True. + + Returns: + str: The accumulated full response string. + """ + full_response = [] + for chunk in stream_response: + if hasattr(chunk, 'choices') and chunk.choices: + delta = chunk.choices[0].delta + if hasattr(delta, 'content') and delta.content: + content_piece = delta.content + print(content_piece, end='', flush=True) + full_response.append(content_piece) + print() # Newline after the stream is complete + return "".join(full_response) diff --git a/grok_client/interactive_chat.py b/grok_client/interactive_chat.py new file mode 100644 index 0000000..eecec16 --- /dev/null +++ b/grok_client/interactive_chat.py @@ -0,0 +1,188 @@ +import os +import sys +import logging +import argparse +from dotenv import load_dotenv +from .grok_openai_client import GrokOpenAIClient + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def parse_arguments(): + """ + Parse command line arguments for the interactive chat application. 
+ + Returns: + argparse.Namespace: The parsed command line arguments. + """ + parser = argparse.ArgumentParser(description='Interactive chat with Grok API using OpenAI-compatible interface') + parser.add_argument('--host', help='API host (default: from .env or 127.0.0.1)') + parser.add_argument('--port', help='API port (default: from .env or 8000)') + parser.add_argument('--model', help='Model name (default: from .env or grok-3)') + parser.add_argument('--sso', help='SSO token (default: from .env)') + parser.add_argument('--sso-rw', help='SSO-RW token (default: from .env)') + parser.add_argument('--json', action='store_true', help='Request responses in JSON format') + parser.add_argument('--system', help='Custom system message') + parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for response generation (default: 1.0)') + + return parser.parse_args() + +def setup_client(args): + """ + Set up the Grok OpenAI client using command line arguments or environment variables. + + Args: + args (argparse.Namespace): The parsed command line arguments. + + Returns: + GrokOpenAIClient: The initialized client. + """ + try: + # Initialize client with args or environment variables + client = GrokOpenAIClient( + api_host=args.host, + api_port=args.port, + model_name=args.model, + sso_token=args.sso, + sso_rw_token=args.sso_rw, + load_from_env=True # Always try to load from env first + ) + + return client + except ValueError as e: + logger.error(f"Error initializing client: {e}") + logger.info("Make sure you have the required environment variables set in .env file or provided as arguments.") + logger.info("Required variables: GROK_SSO, GROK_SSO_RW") + logger.info("Optional variables: API_HOST, API_PORT, MODEL_NAME") + sys.exit(1) + +def interactive_chat(): + """ + Run an interactive chat session with Grok using the OpenAI-compatible interface. 
+ """ + # Parse command line arguments + args = parse_arguments() + + # Set up client + client = setup_client(args) + model_name = client.model_name + + # Set up system message + system_message = args.system + if args.json and not system_message: + system_message = "You are a helpful assistant that always responds in valid JSON format." + elif not system_message: + system_message = "You are a helpful assistant." + + print(f"\n===== Grok Interactive Chat ({model_name}) =====") + print("Type 'exit', 'quit', or Ctrl+C to end the conversation.") + print("Type 'clear' to start a new conversation.") + print("Type '/help' to see available commands.") + print("==============================\n") + + # Initialize conversation history + conversation = [] + if system_message: + conversation.append({"role": "system", "content": system_message}) + + try: + while True: + # Get user input + user_input = input("\nYou: ") + + # Check for exit commands + if user_input.lower() in ['exit', 'quit']: + print("\nExiting chat. 
Goodbye!") + break + + # Check for clear command + if user_input.lower() == 'clear': + conversation = [] + if system_message: + conversation.append({"role": "system", "content": system_message}) + print("\nConversation history cleared.") + continue + + # Check for help command + if user_input.lower() == '/help': + print("\nAvailable commands:") + print(" exit, quit - Exit the chat") + print(" clear - Clear conversation history") + print(" /help - Show this help message") + print(" /json - Toggle JSON response format") + print(" /temp - Set temperature (0.0-2.0)") + print(" /system - Set system message") + continue + + # Check for JSON toggle command + if user_input.lower() == '/json': + args.json = not args.json + print(f"\nJSON response format: {'enabled' if args.json else 'disabled'}") + continue + + # Check for temperature command + if user_input.lower().startswith('/temp '): + try: + new_temp = float(user_input.split(' ', 1)[1]) + if 0.0 <= new_temp <= 2.0: + args.temperature = new_temp + print(f"\nTemperature set to: {args.temperature}") + else: + print("\nTemperature must be between 0.0 and 2.0") + except (ValueError, IndexError): + print("\nInvalid temperature value. 
Format: /temp 0.7") + continue + + # Check for system message command + if user_input.lower().startswith('/system '): + system_message = user_input.split(' ', 1)[1] + # Update the system message in the conversation + conversation = [msg for msg in conversation if msg["role"] != "system"] + conversation.insert(0, {"role": "system", "content": system_message}) + print(f"\nSystem message updated.") + continue + + # Add user message to conversation + conversation.append({"role": "user", "content": user_input}) + + try: + # Send request to Grok API + print("\nGrok: ", end="", flush=True) + + # Prepare request parameters + params = { + "messages": conversation, + "stream": True, + "temperature": args.temperature + } + + # Add JSON format if requested + if args.json: + params["response_format"] = {"type": "json_object"} + + # Use streaming for a more interactive experience + stream = client.chat_completion(**params) + + # Process the streaming response + full_response = client.process_streaming_response(stream) + + # Add assistant response to conversation history + conversation.append({"role": "assistant", "content": full_response}) + + except Exception as e: + logger.error(f"Error: {str(e)}") + print(f"\nAn error occurred: {str(e)}") + + except KeyboardInterrupt: + print("\n\nExiting chat. 
Goodbye!") + +def main(): + # Load environment variables + load_dotenv() + + # Run interactive chat + interactive_chat() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/grok_client/server.py b/grok_client/server.py new file mode 100644 index 0000000..ced1996 --- /dev/null +++ b/grok_client/server.py @@ -0,0 +1,338 @@ +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from typing import List, Optional, Dict, Any, Union +from pydantic import BaseModel, Field +from .client import GrokClient +from .errors import GrokAPIError, AuthenticationError, NetworkError, ConfigurationError, GrokClientError +import json +import time +import logging +import os +import uuid + +# Set up logging +_DEFAULT_LOG_LEVEL = "INFO" +_ALLOWED_LOG_LEVELS = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + +def get_log_level_from_env(): + env_log_level = os.environ.get("GROK_LOG_LEVEL", _DEFAULT_LOG_LEVEL).upper() + if env_log_level not in _ALLOWED_LOG_LEVELS: + logging.basicConfig(level=logging.WARNING) # Temp for this message + logging.warning(f"Invalid GROK_LOG_LEVEL '{env_log_level}'. 
Defaulting to '{_DEFAULT_LOG_LEVEL}'.")
+        logging.basicConfig(level=_DEFAULT_LOG_LEVEL) # Reset
+        return _DEFAULT_LOG_LEVEL
+    return env_log_level
+
+LOG_LEVEL = get_log_level_from_env()
+logging.basicConfig(level=LOG_LEVEL); logging.getLogger().handlers[0].setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - [%(request_id)s] - %(message)s', defaults={'request_id': 'N/A'})) # NOTE(review): basicConfig() has no 'defaults' kwarg and raises ValueError on unknown kwargs; Formatter(defaults=...) (Py>=3.10) achieves the intended fallback
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+async def add_request_id_middleware(request: Request, call_next):
+    request_id = request.headers.get("X-Request-ID")
+    if not request_id:
+        request_id = str(uuid.uuid4())
+
+    request.state.request_id = request_id
+
+    response = await call_next(request)
+    response.headers["X-Request-ID"] = request_id
+    return response
+
+app.middleware("http")(add_request_id_middleware)
+
+# Enable CORS
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+    function_call: Optional[Dict[str, Any]] = None
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+class Function(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    max_tokens: Optional[int] = None
+    functions: Optional[List[Function]] = None
+    function_call: Optional[Union[str, Dict[str, str]]] = None
+    response_format: Optional[Dict[str, str]] = None
+
+class ChatCompletionChoice(BaseModel):
+    index: int = 0
+    message: ChatMessage
+    finish_reason: str = "stop"
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+    function_call: Optional[Dict[str, Any]] = None
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int
+    model: str
+    choices: List[Dict[str, Any]]
+
+class GrokAPI:
+    def __init__(self, cookies: Dict[str, str]):
+        self.client = GrokClient(cookies)
+
+    def _prepare_system_message(self, request: ChatCompletionRequest) -> str:
+        # Default to simple responses unless specifically asked for structured output
+        system_content = "You are a helpful assistant. Provide direct, simple answers to questions."
+
+        # Add function calling instructions if needed
+        if request.functions:
+            system_content = "You are a helpful assistant that provides structured data."
+            system_content += f" Available functions: {[f.name for f in request.functions]}"
+            system_content += f" Function schemas: {json.dumps([f.dict() for f in request.functions])}"
+
+        # Add JSON format instructions if needed
+        elif request.response_format and request.response_format.get("type") == "json_object":
+            system_content = "You are a helpful assistant that always responds in valid JSON format."
+ + return system_content + + def stream_chat(self, request_data: ChatCompletionRequest, request_id: str): + try: + # Prepare the conversation context + system_msg = self._prepare_system_message(request_data) + conversation = f"system: {system_msg}\n" + "\n".join([f"{msg.role}: {msg.content}" for msg in request_data.messages]) + + logger.debug(f"Sending conversation to Grok: {conversation}", extra={'request_id': request_id}) + + # Get streaming response from Grok + response_stream = self.client.send_message(conversation) # send_message itself logs with its own context + logger.debug(f"Got response stream from Grok: {response_stream}", extra={'request_id': request_id}) + + # Stream the response in OpenAI format + for token in response_stream.split(): + chunk = ChatCompletionChunk( + id="chatcmpl-" + str(int(time.time())), + created=int(time.time()), + model="grok-3", + choices=[{ + "index": 0, + "delta": {"content": token + " "}, + "finish_reason": None + }] + ) + yield f"data: {json.dumps(chunk.dict())}\n\n" + + # Send the final chunk + final_chunk = ChatCompletionChunk( + id="chatcmpl-final", + created=int(time.time()), + model="grok-3", # Or use request_data.model + choices=[{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + ) + yield f"data: {json.dumps(final_chunk.dict())}\n\n" + # yield "data: [DONE]\n\n" # This is handled by finally + except AuthenticationError as e: + logger.error(f"AuthenticationError during streaming: {str(e)}", exc_info=True, extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'AuthenticationError', 'code': 401}})}\n\n" + except ConfigurationError as e: + logger.error(f"ConfigurationError during streaming: {str(e)}", exc_info=True, extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'ConfigurationError', 'code': 400}})}\n\n" + except NetworkError as e: + logger.error(f"NetworkError during streaming: {str(e)}", exc_info=True, 
extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'NetworkError', 'code': 502}})}\n\n" + except GrokAPIError as e: + logger.error(f"GrokAPIError during streaming: {str(e)}", exc_info=True, extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'GrokAPIError', 'code': 502}})}\n\n" + except GrokClientError as e: + logger.error(f"GrokClientError during streaming: {str(e)}", exc_info=True, extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'GrokClientError', 'code': 500}})}\n\n" + except Exception as e: + logger.error(f"Unexpected error in stream_chat: {str(e)}", exc_info=True, extra={'request_id': request_id}) + yield f"data: {json.dumps({'error': {'message': f'An unexpected error occurred during streaming: {str(e)}', 'type': 'ServerError', 'code': 500}})}\n\n" + finally: + logger.info("Finished streaming chat attempt.", extra={'request_id': request_id}) + yield "data: [DONE]\n\n" + +@app.get("/v1/models") +async def list_models(): + return { + "data": [ + { + "id": "grok-3", + "object": "model", + "created": int(time.time()), + "owned_by": "xai", + "permission": [], + "root": "grok-3", + "parent": None + } + ] + } + +@app.post("/v1/chat/completions") +async def create_chat_completion(raw_request: Request): + try: + request_id = raw_request.state.request_id # Get request_id from middleware + # Get request body + body = await raw_request.json() + logger.debug(f"Received request body: {body}", extra={'request_id': request_id}) + + # Parse request into ChatCompletionRequest + request_data = ChatCompletionRequest(**body) # Renamed to request_data + + # Get cookies from request headers + headers = dict(raw_request.headers) + logger.debug(f"Received headers: {headers}", extra={'request_id': request_id}) + + cookies = {'Cookie': headers.get('cookie', '')} if headers.get('cookie') else {} + logger.debug(f"Extracted cookies: 
{cookies}", extra={'request_id': request_id}) + + if not cookies: + # Log before raising HTTPException, as HTTPException might not be logged with request_id by default handler + logger.warning("No authentication cookies provided.", extra={'request_id': request_id}) + raise HTTPException(status_code=401, detail="No authentication cookies provided") + + # Initialize Grok API with cookies + grok = GrokAPI(cookies) # GrokClient init logs internally, request_id not directly available there + + if request_data.stream: + return StreamingResponse( + grok.stream_chat(request_data, request_id), # Pass request_id + media_type="text/event-stream" + ) + + # For non-streaming response + system_msg = grok._prepare_system_message(request_data) + conversation = f"system: {system_msg}\n" + "\n".join([f"{msg.role}: {msg.content}" for msg in request_data.messages]) + logger.debug(f"Sending conversation to Grok: {conversation}", extra={'request_id': request_id}) + + response = grok.client.send_message(conversation) # send_message logs internally + logger.debug(f"Received response from Grok: {response}", extra={'request_id': request_id}) + + if not response: + logger.error("Empty response from Grok API", extra={'request_id': request_id}) + raise HTTPException(status_code=500, detail="Empty response from Grok API") + + # Handle function calling + if request_data.functions and request_data.function_call: + try: + # Try to parse the response as JSON + parsed_response = json.loads(response) + + # Get the function name from the request + function_name = request_data.function_call.get("name", request_data.functions[0].name) if isinstance(request_data.function_call, dict) else request_data.functions[0].name + + message = ChatMessage( + role="assistant", + content="", + function_call={ + "name": function_name, + "arguments": json.dumps(parsed_response) + } + ) + except json.JSONDecodeError: + logger.warning(f"Function call response is not valid JSON. 
Original response: {response}", extra={'request_id': request_id}) + # If response is not valid JSON, wrap it in a basic structure + function_name = request_data.function_call.get("name", request_data.functions[0].name) if isinstance(request_data.function_call, dict) else request_data.functions[0].name + message = ChatMessage( + role="assistant", + content="", + function_call={ + "name": function_name, + "arguments": json.dumps({"result": response}) + } + ) + else: + # Regular response or JSON format + if request_data.response_format and request_data.response_format.get("type") == "json_object": + try: + # Ensure the response is valid JSON + json.loads(response) # Validate + message = ChatMessage( + role="assistant", + content=response + ) + except json.JSONDecodeError: + logger.warning(f"JSON format requested, but response is not valid JSON. Original response: {response}", extra={'request_id': request_id}) + # If not valid JSON, wrap it in a JSON structure + message = ChatMessage( + role="assistant", + content=json.dumps({"response": response}) # Wrap to make it JSON + ) + else: + message = ChatMessage( + role="assistant", + content=response + ) + + # Create response object + chat_response = ChatCompletionResponse( + id=f"chatcmpl-{str(int(time.time()))}", # Consider using request_id or part of it for traceability + created=int(time.time()), + model=request_data.model, + choices=[ChatCompletionChoice( + message=message, + finish_reason="stop" + )] + ) + + logger.debug(f"Sending response: {chat_response.dict()}", extra={'request_id': request_id}) + return chat_response + + except AuthenticationError as e: + logger.error(f"AuthenticationError in chat completion: {str(e)}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + return JSONResponse(status_code=401, content={"error": {"message": str(e), "type": "AuthenticationError"}}) + except ConfigurationError as e: + 
logger.error(f"ConfigurationError in chat completion: {str(e)}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + return JSONResponse(status_code=400, content={"error": {"message": str(e), "type": "ConfigurationError"}}) + except NetworkError as e: + logger.error(f"NetworkError in chat completion: {str(e)}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + return JSONResponse(status_code=502, content={"error": {"message": str(e), "type": "NetworkError"}}) + except GrokAPIError as e: + logger.error(f"GrokAPIError in chat completion: {str(e)}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + return JSONResponse(status_code=502, content={"error": {"message": str(e), "type": "GrokAPIError"}}) + except GrokClientError as e: + logger.error(f"GrokClientError in chat completion: {str(e)}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + return JSONResponse(status_code=500, content={"error": {"message": str(e), "type": "GrokClientError"}}) + except HTTPException as http_exc: + # If we want to log HTTPExceptions with request_id, we need to catch, log, and re-raise or return response + logger.error(f"HTTPException in chat completion: Status {http_exc.status_code}, Detail {http_exc.detail}", exc_info=True, extra={'request_id': raw_request.state.request_id if hasattr(raw_request.state, 'request_id') else 'N/A'}) + raise http_exc # Re-raise to let FastAPI handle the response + except Exception as e: + # Ensure request_id is available for logging if possible + req_id = 'N/A' + if hasattr(raw_request, 'state') and hasattr(raw_request.state, 'request_id'): + req_id = raw_request.state.request_id + logger.error(f"Unexpected error in create_chat_completion: {str(e)}", exc_info=True, 
extra={'request_id': req_id})
+        return JSONResponse(
+            status_code=500,
+            content={"error": {"message": f"An unexpected server error occurred: {str(e)}", "type": "ServerError"}}
+        )
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6b11265
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+requests>=2.31.0
+fastapi==0.109.2
+uvicorn==0.27.1
+python-dotenv>=1.0.0
+# NOTE(review): removed stray shell commands that made this file unparseable by
+# pip (a requirements file holds one requirement specifier or comment per line).
+# 'pip install dotenv' would also have installed the wrong package ('dotenv'
+# instead of 'python-dotenv'). Setup: python -m venv venv, activate it, then pip install -r requirements.txt
\ No newline at end of file