diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index 9f8bb5ca59..86f8b85f69 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -429,6 +429,8 @@ func toGeminiRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai // toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool]. func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { var outTools []*genai.Tool + functions := []*genai.FunctionDeclaration{} + for _, t := range inTools { if !validToolName(t.Name) { return nil, fmt.Errorf(`invalid tool name: %q, must start with a letter or an underscore, must be alphanumeric, underscores, dots or dashes with a max length of 64 chars`, t.Name) @@ -442,8 +444,15 @@ func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { Parameters: inputSchema, Description: t.Description, } - outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}}) + functions = append(functions, fd) } + + if len(functions) > 0 { + outTools = append(outTools, &genai.Tool{ + FunctionDeclarations: functions, + }) + } + return outTools, nil } diff --git a/go/plugins/googlegenai/gemini_test.go b/go/plugins/googlegenai/gemini_test.go index c58f56fcfc..ca54769ce5 100644 --- a/go/plugins/googlegenai/gemini_test.go +++ b/go/plugins/googlegenai/gemini_test.go @@ -45,6 +45,9 @@ func TestConvertRequest(t *testing.T) { Temperature: genai.Ptr[float32](0.4), TopK: genai.Ptr[float32](0.1), TopP: genai.Ptr[float32](1.0), + Tools: []*genai.Tool{ + {GoogleSearch: &genai.GoogleSearch{}}, + }, ThinkingConfig: &genai.ThinkingConfig{ IncludeThoughts: false, ThinkingBudget: genai.Ptr[int32](0), @@ -151,6 +154,9 @@ func TestConvertRequest(t *testing.T) { if gcc.ThinkingConfig == nil { t.Errorf("ThinkingConfig should not be empty") } + if len(gcc.Tools) != 1 { + t.Errorf("tools should have been: 1, got: %d", len(gcc.Tools)) + } }) t.Run("use valid tools outside genkit", func(t *testing.T) { badCfg := genai.GenerateContentConfig{ diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index e0145f7771..41939755a7 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -82,6 +82,12 @@ func TestGoogleAILive(t *testing.T) { }, ) + answerOfEverythingTool := genkit.DefineTool(g, "answerOfEverything", "use this tool when the user asks for the answer of life, the universe and everything", + func(ctx *ai.ToolContext, input any) (int, error) { + return 42, nil + }, + ) + t.Run("embedder", func(t *testing.T) { res, err := genkit.Embed(ctx, g, ai.WithEmbedder(embedder), ai.WithTextDocs("yellow banana")) if err != nil { @@ -191,7 +197,43 @@ func TestGoogleAILive(t *testing.T) { t.Errorf("got %q, expecting it to contain %q", out, want) } }) - + t.Run("api side tools", func(t *testing.T) { + m := googlegenai.GoogleAIModel(g, "gemini-2.5-flash") + _, err := genkit.Generate(ctx, g, + ai.WithConfig(&genai.GenerateContentConfig{ + Tools: []*genai.Tool{ + {GoogleSearch: &genai.GoogleSearch{}}, + {CodeExecution: &genai.ToolCodeExecution{}}, + }, + }), + ai.WithModel(m), + ai.WithPrompt("When is the next lunar eclipse in US?")) + if err != nil { + t.Fatal(err) + } + }) + t.Run("api and custom tools", func(t *testing.T) { + m := googlegenai.GoogleAIModel(g, "gemini-2.5-flash") + resp, err := genkit.Generate(ctx, g, + ai.WithConfig(&genai.GenerateContentConfig{ + Tools: []*genai.Tool{ + {GoogleSearch: &genai.GoogleSearch{}}, + }, + }), + ai.WithModel(m), + ai.WithTools(gablorkenTool, answerOfEverythingTool), + ai.WithPrompt("What is the answer of life?")) + if err != nil { + t.Fatal(err) + } + // api tools should not be used when custom tools are present + if len(resp.Request.Tools) != 2 { + t.Fatalf("got %d tools, want: 2", len(resp.Request.Tools)) + } + if !strings.Contains(resp.Text(), "42") { + t.Fatalf("got %s, want: 42", resp.Text()) + } + }) t.Run("tool with json output", func(t *testing.T) { type weatherQuery struct { Location string `json:"location"` diff --git a/go/samples/prompts/main.go b/go/samples/prompts/main.go index 031903570c..6180acee48 100644 --- a/go/samples/prompts/main.go +++ b/go/samples/prompts/main.go @@ -227,6 +227,12 @@ func PromptWithTool(ctx context.Context, g *genkit.Genkit) { }, ) + answerOfEverythingTool := genkit.DefineTool(g, "answerOfEverything", "use this tool when the user asks for the answer of life, the universe and everything", + func(ctx *ai.ToolContext, input any) (int, error) { + return 42, nil + }, + ) + type Output struct { Gablorken float64 `json:"gablorken"` } @@ -236,7 +242,7 @@ func PromptWithTool(ctx context.Context, g *genkit.Genkit) { g, "PromptWithTool", ai.WithToolChoice(ai.ToolChoiceAuto), ai.WithMaxTurns(1), - ai.WithTools(gablorkenTool), + ai.WithTools(gablorkenTool, answerOfEverythingTool), ai.WithOutputType(Output{}), ai.WithPrompt("what is a gablorken of 2 over 3.5?"), )