diff --git a/chat.go b/chat.go index c8a3e81b3..18a0b4106 100644 --- a/chat.go +++ b/chat.go @@ -280,6 +280,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Add additional JSON properties to the request + ExtraBody map[string]any `json:"extra_body,omitempty"` } type StreamOptions struct { @@ -425,6 +427,7 @@ func (c *Client) CreateChatCompletion( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraBody(request.ExtraBody), ) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..89c335d65 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -96,6 +96,7 @@ func (c *Client) CreateChatCompletionStream( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraBody(request.ExtraBody), ) if err != nil { return nil, err diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..c468d9452 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1021,3 +1021,486 @@ func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) } return true } + +// Helper functions for TestCreateChatCompletionStreamExtraBody to reduce complexity and improve maintainability + +func deepEqual(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // Use reflection for deep comparison to handle maps, slices, etc. + aJSON, aErr := json.Marshal(a) + bJSON, bErr := json.Marshal(b) + if aErr != nil || bErr != nil { + return false + } + return string(aJSON) == string(bJSON) +} + +func createBaseChatStreamRequest() openai.ChatCompletionRequest { + return openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } +} + +func validateExtraBodyFields(reqBody map[string]any, expectedExtraFields map[string]any) error { + for key, expectedValue := range expectedExtraFields { + actualValue, exists := reqBody[key] + if !exists { + return fmt.Errorf("ExtraBody field %s not found in request", key) + } + + // Handle complex types comparison safely + if !deepEqual(actualValue, expectedValue) { + return fmt.Errorf("ExtraBody field %s value mismatch: expected %v, got %v", + key, expectedValue, actualValue) + } + } + return nil +} + +func validateStandardFields(reqBody map[string]any) error { + if reqBody["model"] != "gpt-4" { + return fmt.Errorf("standard model field not found") + } + + if reqBody["stream"] != true { + return fmt.Errorf("stream field should be true") + } + return nil +} + +func parseRequestBody(r *http.Request) (map[string]any, error) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + return nil, fmt.Errorf("failed to parse request body: %w", err) + } + return reqBody, nil +} + +func writeStreamingResponse(t *testing.T, w http.ResponseWriter) { + t.Helper() + w.Header().Set("Content-Type", "text/event-stream") + + responses := []string{ + `{"id":"test-1","object":"chat.completion.chunk","created":1598069254,` + + `"model":"gpt-4","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`, + `{"id":"test-2","object":"chat.completion.chunk","created":1598069255,` + + `"model":"gpt-4","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + } + + dataBytes := []byte{} + for _, response := range responses { + dataBytes = append(dataBytes, []byte("event: message\n")...) + dataBytes = append(dataBytes, []byte("data: "+response+"\n\n")...) + } + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } +} + +func createStreamHandler(t *testing.T, expectedExtraFields map[string]any) func( + w http.ResponseWriter, r *http.Request) { + t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + if expectedExtraFields == nil { + writeStreamingResponse(t, w) + return + } + + reqBody, err := parseRequestBody(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if validationErr := validateExtraBodyFields(reqBody, expectedExtraFields); validationErr != nil { + http.Error(w, validationErr.Error(), http.StatusBadRequest) + return + } + + if standardErr := validateStandardFields(reqBody); standardErr != nil { + http.Error(w, standardErr.Error(), http.StatusBadRequest) + return + } + + writeStreamingResponse(t, w) + } +} + +func verifyStreamResponse(t *testing.T, stream *openai.ChatCompletionStream, + expectedResponses []openai.ChatCompletionStreamResponse) { + t.Helper() + + if stream == nil { + t.Fatal("Stream is nil - cannot verify response") + return + } + + defer stream.Close() + + for ix, expectedResponse := range expectedResponses { + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func testStreamExtraBodyWithParameters(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + expectedExtraFields := map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": float64(123), // JSON unmarshaling converts numbers to float64 + "temperature": float64(0.7), + } + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, expectedExtraFields)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": 123, + "temperature": 0.7, + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyComplexData(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + expectedExtraFields := map[string]any{ + "array_param": []interface{}{"item1", "item2"}, + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_config": map[string]interface{}{"enabled": true, "level": float64(5)}, + "mixed_array": []interface{}{"string", float64(42), true, nil}, + "float_param": 3.14159, + "negative_int": float64(-42), + "zero_value": float64(0), + } + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, expectedExtraFields)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{ + "array_param": []string{"item1", "item2"}, + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_config": map[string]any{ + "enabled": true, + "level": 5, + }, + "mixed_array": []any{"string", 42, true, nil}, + "float_param": 3.14159, + "negative_int": -42, + "zero_value": 0, + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with complex ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyEmpty(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, nil)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{} + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with empty ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyNil(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, nil)) + + req := createBaseChatStreamRequest() + req.ExtraBody = nil + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with nil ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyFieldConflicts(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Handler that verifies ExtraBody fields override standard fields + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + http.Error(w, "Failed to parse request body", http.StatusInternalServerError) + return + } + + // Verify ExtraBody fields override standard fields + if reqBody["model"] != "overridden-model" { + msg := fmt.Sprintf("Model field should be overridden to 'overridden-model', got %v", + reqBody["model"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + if reqBody["stream"] != true { + http.Error(w, fmt.Sprintf("Stream field should remain true, got %v", reqBody["stream"]), http.StatusBadRequest) + return + } + + maxTokens, ok := reqBody["max_tokens"].(float64) + if !ok || int(maxTokens) != 9999 { + msg := fmt.Sprintf("MaxTokens field should be overridden to 9999, got %v", + reqBody["max_tokens"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + // Verify custom field from ExtraBody is present at top level + if reqBody["custom_field"] != "custom_value" { + msg := fmt.Sprintf("Custom field from ExtraBody should be 'custom_value', got %v", + reqBody["custom_field"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + // Send streaming response using the overridden model name + w.Header().Set("Content-Type", "text/event-stream") + data := `{"id":"test-1","object":"chat.completion.chunk","created":1598069254,` + + `"model":"overridden-model","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{"content":"Response"},"finish_reason":"stop"}]}` + _, writeErr := w.Write([]byte("data: " + data + "\n\ndata: [DONE]\n\n")) + if writeErr != nil { + t.Errorf("Failed to write response: %v", writeErr) + } + }) + + req := createBaseChatStreamRequest() + req.MaxTokens = 100 + req.ExtraBody = map[string]any{ + "model": "overridden-model", // this should override the standard model field + "max_tokens": 9999, // this should override the standard max_tokens field + "custom_field": "custom_value", // this is a new field + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with field overrides should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "overridden-model", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Response", + }, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func TestCreateChatCompletionStreamExtraBody(t *testing.T) { + t.Run("WithParameters", testStreamExtraBodyWithParameters) + t.Run("ComplexData", testStreamExtraBodyComplexData) + t.Run("EmptyExtraBody", testStreamExtraBodyEmpty) + t.Run("NilExtraBody", testStreamExtraBodyNil) + t.Run("FieldConflicts", testStreamExtraBodyFieldConflicts) +} diff --git a/chat_test.go b/chat_test.go index 514706c96..28edcbdbe 100644 --- a/chat_test.go +++ b/chat_test.go @@ -916,6 +916,429 @@ func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error return completion, nil } +// Helper functions for TestChatCompletionRequestExtraBody to reduce complexity and improve maintainability + +func createBaseChatRequest() openai.ChatCompletionRequest { + return openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } +} + +func verifyJSONContainsFields(t *testing.T, jsonStr string, expectedFields map[string]string) { + t.Helper() + for _, expected := range expectedFields { + if !strings.Contains(jsonStr, expected) { + t.Errorf("Expected JSON to contain %s, got: %s", expected, jsonStr) + } + } +} + +func verifyExtraBodyExists(t *testing.T, extraBody map[string]any) { + t.Helper() + if extraBody == nil { + t.Fatal("ExtraBody should not be nil after unmarshaling") + } +} + +func verifyStringField(t *testing.T, extraBody map[string]any, fieldName, expected string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + if value != expected { + t.Errorf("Expected %s to be '%s', got %v", fieldName, expected, value) + } +} + +func verifyFloatField(t *testing.T, extraBody map[string]any, fieldName string, expected float64) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if floatValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, floatValue) + } +} + +func verifyIntField(t *testing.T, extraBody map[string]any, fieldName string, expected int) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if int(floatValue) != expected { + t.Errorf("Expected %s to be %d, got %v", fieldName, expected, int(floatValue)) + } +} + +func verifyBoolField(t *testing.T, extraBody map[string]any, fieldName string, expected bool) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + boolValue, ok := value.(bool) + if !ok { + t.Errorf("Expected %s to be bool, got type %T", fieldName, value) + return + } + if boolValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, boolValue) + } +} + +func verifyArrayField(t *testing.T, extraBody map[string]any, fieldName string, expected []interface{}) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + arrayValue, ok := value.([]interface{}) + if !ok { + t.Errorf("Expected %s to be []interface{}, got type %T", fieldName, value) + return + } + if len(arrayValue) != len(expected) { + t.Errorf("Expected %s to have %d elements, got %d", fieldName, len(expected), len(arrayValue)) + return + } + for i, expectedVal := range expected { + if arrayValue[i] != expectedVal { + t.Errorf("%s[%d]: expected %v, got %v", fieldName, i, expectedVal, arrayValue[i]) + } + } +} + +func verifyNestedObject(t *testing.T, extraBody map[string]any, fieldName, nestedKey, expectedValue string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + objectValue, ok := value.(map[string]interface{}) + if !ok { + t.Errorf("Expected %s to be map[string]interface{}, got type %T", fieldName, value) + return + } + nestedValue, nestedExists := objectValue[nestedKey] + if !nestedExists { + t.Errorf("%s should exist in %s", nestedKey, fieldName) + return + } + if nestedValue != expectedValue { + t.Errorf("Expected %s.%s to be '%s', got %v", fieldName, nestedKey, expectedValue, nestedValue) + } +} + +func verifyDeepNesting(t *testing.T, extraBody map[string]any) { + t.Helper() + deepNesting, ok := extraBody["deep_nesting"].(map[string]interface{}) + if !ok { + t.Error("deep_nesting should be map[string]interface{}") + return + } + level1, ok := deepNesting["level1"].(map[string]interface{}) + if !ok { + t.Error("level1 should be map[string]interface{}") + return + } + level2, ok := level1["level2"].(map[string]interface{}) + if !ok { + t.Error("level2 should be map[string]interface{}") + return + } + value, ok := level2["value"].(string) + if !ok { + t.Error("deep nested value should be string") + return + } + if value != "deep_value" { + t.Errorf("Expected deep nested value to be 'deep_value', got %v", value) + } +} + +func testExtraBodySerialization(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "custom_param": "custom_value", + "numeric_param": 42, + "boolean_param": true, + "array_param": []string{"item1", "item2"}, + "object_param": map[string]any{ + "nested_key": "nested_value", + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with ExtraBody") + + // Verify JSON serialization + expectedFields := map[string]string{ + "extra_body": `"extra_body"`, + "custom_param": `"custom_param":"custom_value"`, + "numeric_param": `"numeric_param":42`, + "boolean_param": `"boolean_param":true`, + } + verifyJSONContainsFields(t, string(data), expectedFields) + + // Verify deserialization + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyStringField(t, unmarshaled.ExtraBody, "custom_param", "custom_value") + verifyIntField(t, unmarshaled.ExtraBody, "numeric_param", 42) + verifyBoolField(t, unmarshaled.ExtraBody, "boolean_param", true) + verifyArrayField(t, unmarshaled.ExtraBody, "array_param", []interface{}{"item1", "item2"}) + verifyNestedObject(t, unmarshaled.ExtraBody, "object_param", "nested_key", "nested_value") +} + +func testEmptyExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{} + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with empty ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Empty ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with empty ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should be nil when empty ExtraBody is omitted from JSON") + } +} + +func testNilExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = nil + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with nil ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Nil ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with nil ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should remain nil when not present in JSON") + } +} + +func testComplexDataTypes(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "float_param": 3.14159, + "negative_int": -42, + "zero_value": 0, + "empty_string": "", + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_arrays": []any{[]string{"a", "b"}, []int{1, 2, 3}}, + "mixed_array": []any{"string", 42, true, nil}, + "deep_nesting": map[string]any{ + "level1": map[string]any{ + "level2": map[string]any{ + "value": "deep_value", + }, + }, + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with complex ExtraBody") + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with complex ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyFloatField(t, unmarshaled.ExtraBody, "float_param", 3.14159) + verifyIntField(t, unmarshaled.ExtraBody, "negative_int", -42) + verifyStringField(t, unmarshaled.ExtraBody, "unicode_text", "你好世界") + verifyArrayField(t, unmarshaled.ExtraBody, "mixed_array", []interface{}{"string", float64(42), true, nil}) + verifyDeepNesting(t, unmarshaled.ExtraBody) +} + +func testInvalidJSONHandling(t *testing.T) { + t.Helper() + invalidJSON := `{"model":"gpt-4","extra_body":{"invalid_json":}}` + var req openai.ChatCompletionRequest + err := json.Unmarshal([]byte(invalidJSON), &req) + if err == nil { + t.Error("Expected error when unmarshaling invalid JSON, but got nil") + } +} + +func testExtraBodyFieldConflicts(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.MaxTokens = 100 + req.ExtraBody = map[string]any{ + "model": "should-not-override", + "max_tokens": 9999, + "custom_field": "custom_value", + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with field conflicts in ExtraBody") + + var jsonMap map[string]any + err = json.Unmarshal(data, &jsonMap) + checks.NoError(t, err, "Failed to unmarshal JSON to generic map") + + if jsonMap["model"] != "gpt-4" { + t.Errorf("Standard model field should be 'gpt-4', got %v", jsonMap["model"]) + } + + maxTokens, ok := jsonMap["max_tokens"].(float64) + if !ok || int(maxTokens) != 100 { + t.Errorf("Standard max_tokens field should be 100, got %v", jsonMap["max_tokens"]) + } + + extraBody, ok := jsonMap["extra_body"].(map[string]interface{}) + if !ok { + t.Error("ExtraBody should be present in JSON") + return + } + customField, ok := extraBody["custom_field"].(string) + if !ok || customField != "custom_value" { + t.Errorf("Expected custom_field to be 'custom_value', got %v", customField) + } +} + +func TestChatCompletionRequestExtraBody(t *testing.T) { + t.Run("ExtraBodySerialization", testExtraBodySerialization) + t.Run("EmptyExtraBody", testEmptyExtraBody) + t.Run("NilExtraBody", testNilExtraBody) + t.Run("ComplexDataTypes", testComplexDataTypes) + t.Run("InvalidJSONHandling", testInvalidJSONHandling) + t.Run("ExtraBodyFieldConflicts", testExtraBodyFieldConflicts) +} + +func TestChatCompletionWithExtraBody(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Set up a handler that verifies ExtraBody fields are merged into the request body + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + http.Error(w, "Failed to parse request body", http.StatusInternalServerError) + return + } + + // Verify that ExtraBody fields are merged at the top level + if reqBody["custom_parameter"] != "test_value" { + http.Error(w, "ExtraBody custom_parameter not found in request", http.StatusBadRequest) + return + } + if reqBody["additional_config"] != true { + http.Error(w, "ExtraBody additional_config not found in request", http.StatusBadRequest) + return + } + + // Verify standard fields are still present + if reqBody["model"] != "gpt-4" { + http.Error(w, "Standard model field not found", http.StatusBadRequest) + return + } + + // Return a mock response + res := openai.ChatCompletionResponse{ + ID: "test-id", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "gpt-4", + Choices: []openai.ChatCompletionChoice{ + { + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: "Hello! I received your message with extra parameters.", + }, + FinishReason: openai.FinishReasonStop, + }, + }, + Usage: openai.Usage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + }, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(res) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + }) + + // Test ChatCompletion with ExtraBody + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + ExtraBody: map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": 123, + "array_setting": []string{"option1", "option2"}, + }, + }) + + checks.NoError(t, err, "CreateChatCompletion with ExtraBody should not fail") +} + func TestFinishReason(t *testing.T) { c := &openai.ChatCompletionChoice{ FinishReason: openai.FinishReasonNull, diff --git a/client.go b/client.go index cef375348..edb66ed00 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,48 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + if len(extraBody) == 0 { + return // No extra body to merge + } + + // Check if args.body is already a map[string]any + if bodyMap, ok := args.body.(map[string]any); ok { + // If it's already a map[string]any, directly add extraBody fields + for key, value := range extraBody { + bodyMap[key] = value + } + return + } + + // If args.body is a struct, convert it to map[string]any first + if args.body != nil { + var err error + var jsonBytes []byte + // Marshal the struct to JSON bytes + jsonBytes, err = json.Marshal(args.body) + if err != nil { + return // If marshaling fails, skip merging ExtraBody + } + + // Unmarshal JSON bytes to map[string]any + var bodyMap map[string]any + if err = json.Unmarshal(jsonBytes, &bodyMap); err != nil { + return // If unmarshaling fails, skip merging ExtraBody + } + + // Merge ExtraBody fields into the map + for key, value := range extraBody { + bodyMap[key] = value + } + + // Replace args.body with the merged map + args.body = bodyMap + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType)